diff --git a/.github/workflows/pr_test_peft_backend.yml b/.github/workflows/pr_test_peft_backend.yml
index 26b4b5a3deec..190e5d26e6f3 100644
--- a/.github/workflows/pr_test_peft_backend.yml
+++ b/.github/workflows/pr_test_peft_backend.yml
@@ -92,12 +92,14 @@ jobs:
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
+ # TODO (sayakpaul, DN6): revisit `--no-deps`
if [ "${{ matrix.lib-versions }}" == "main" ]; then
- python -m pip install -U peft@git+https://github.com/huggingface/peft.git
- python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
+ python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
+ python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
+ pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
else
- python -m uv pip install -U peft transformers accelerate
+ python -m uv pip install -U peft --no-deps
+ python -m uv pip install -U transformers accelerate --no-deps
fi
- name: Environment
diff --git a/docker/diffusers-onnxruntime-cuda/Dockerfile b/docker/diffusers-onnxruntime-cuda/Dockerfile
index bd1d871033c9..6124172e109e 100644
--- a/docker/diffusers-onnxruntime-cuda/Dockerfile
+++ b/docker/diffusers-onnxruntime-cuda/Dockerfile
@@ -28,7 +28,7 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3.10 -m uv pip install --no-cache-dir \
- torch \
+ "torch<2.5.0" \
torchvision \
torchaudio \
"onnxruntime-gpu>=1.13.1" \
diff --git a/docker/diffusers-pytorch-compile-cuda/Dockerfile b/docker/diffusers-pytorch-compile-cuda/Dockerfile
index cb4a9c0f9896..9d7578f5a4dc 100644
--- a/docker/diffusers-pytorch-compile-cuda/Dockerfile
+++ b/docker/diffusers-pytorch-compile-cuda/Dockerfile
@@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3.10 -m uv pip install --no-cache-dir \
- torch \
+ "torch<2.5.0" \
torchvision \
torchaudio \
invisible_watermark && \
diff --git a/docker/diffusers-pytorch-cpu/Dockerfile b/docker/diffusers-pytorch-cpu/Dockerfile
index 8d98c52598d2..1b39e58ca273 100644
--- a/docker/diffusers-pytorch-cpu/Dockerfile
+++ b/docker/diffusers-pytorch-cpu/Dockerfile
@@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3.10 -m uv pip install --no-cache-dir \
- torch \
+ "torch<2.5.0" \
torchvision \
torchaudio \
invisible_watermark \
diff --git a/docker/diffusers-pytorch-cuda/Dockerfile b/docker/diffusers-pytorch-cuda/Dockerfile
index 695f5ed08dc5..7317ef642aa5 100644
--- a/docker/diffusers-pytorch-cuda/Dockerfile
+++ b/docker/diffusers-pytorch-cuda/Dockerfile
@@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3.10 -m uv pip install --no-cache-dir \
- torch \
+ "torch<2.5.0" \
torchvision \
torchaudio \
invisible_watermark && \
diff --git a/docker/diffusers-pytorch-xformers-cuda/Dockerfile b/docker/diffusers-pytorch-xformers-cuda/Dockerfile
index 1693eb293024..356445a6d173 100644
--- a/docker/diffusers-pytorch-xformers-cuda/Dockerfile
+++ b/docker/diffusers-pytorch-xformers-cuda/Dockerfile
@@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
python3.10 -m pip install --no-cache-dir \
- torch \
+ "torch<2.5.0" \
torchvision \
torchaudio \
invisible_watermark && \
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index e218e9878599..58218c0272bd 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -150,6 +150,12 @@
title: Reinforcement learning training with DDPO
title: Methods
title: Training
+- sections:
+ - local: quantization/overview
+ title: Getting Started
+ - local: quantization/bitsandbytes
+ title: bitsandbytes
+ title: Quantization Methods
- sections:
- local: optimization/fp16
title: Speed up inference
@@ -209,6 +215,8 @@
title: Logging
- local: api/outputs
title: Outputs
+ - local: api/quantization
+ title: Quantization
title: Main Classes
- isExpanded: false
sections:
diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md
index 4cde7a111ae6..f0f4fd37e6d5 100644
--- a/docs/source/en/api/pipelines/cogvideox.md
+++ b/docs/source/en/api/pipelines/cogvideox.md
@@ -36,6 +36,10 @@ There are two models available that can be used with the text-to-video and video
There is one model available that can be used with the image-to-video CogVideoX pipeline:
- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`.
+There are two models that support pose-controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team):
+- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `bf16`.
+- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `bf16`.
+
## Inference
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
@@ -118,6 +122,12 @@ It is also worth noting that torchao quantization is fully compatible with [torc
- all
- __call__
+## CogVideoXFunControlPipeline
+
+[[autodoc]] CogVideoXFunControlPipeline
+ - all
+ - __call__
+
## CogVideoXPipelineOutput
[[autodoc]] pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput
diff --git a/docs/source/en/api/pipelines/controlnet_flux.md b/docs/source/en/api/pipelines/controlnet_flux.md
index f63885b4d42c..82454ae5e930 100644
--- a/docs/source/en/api/pipelines/controlnet_flux.md
+++ b/docs/source/en/api/pipelines/controlnet_flux.md
@@ -1,4 +1,4 @@
-
+
diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md
new file mode 100644
--- /dev/null
+++ b/docs/source/en/api/quantization.md
+# Quantization
+
+Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory and speeds up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index).
+
+Quantization techniques that aren't already supported in Diffusers can be added with the [`DiffusersQuantizer`] class.
+
+
+
+Learn how to quantize models in the [Quantization](../quantization/overview) guide.
+
+
+
+
+## BitsAndBytesConfig
+
+[[autodoc]] BitsAndBytesConfig
+
+## DiffusersQuantizer
+
+[[autodoc]] quantizers.base.DiffusersQuantizer
diff --git a/docs/source/en/quantization/bitsandbytes.md b/docs/source/en/quantization/bitsandbytes.md
new file mode 100644
index 000000000000..118511b75d50
--- /dev/null
+++ b/docs/source/en/quantization/bitsandbytes.md
@@ -0,0 +1,260 @@
+
+
+# bitsandbytes
+
+[bitsandbytes](https://huggingface.co/docs/bitsandbytes/index) is the easiest option for quantizing a model to 8-bit and 4-bit. 8-bit quantization multiplies outliers in fp16 with non-outliers in int8, converts the non-outlier values back to fp16, and then adds them together to return the weights in fp16. This reduces the degradative effect outlier values have on a model's performance.
+
+4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs.
+
+
+To use bitsandbytes, make sure you have the following libraries installed:
+
+```bash
+pip install diffusers transformers accelerate bitsandbytes -U
+```
+
+Now you can quantize a model by passing a [`BitsAndBytesConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
+
+
+
+
+Quantizing a model in 8-bit halves the memory usage:
+
+```py
+from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+model_8bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quantization_config
+)
+```
+
+By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:
+
+```py
+import torch
+from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+model_8bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.float32
+)
+model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype
+```
+
+Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
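+
+A minimal sketch (the repository id below is a placeholder to replace with your own):
+
+```py
+# Push the quantized weights and quantization config to the Hub.
+model_8bit.push_to_hub("your-username/flux.1-dev-transformer-8bit")
+
+# Or serialize them locally instead.
+model_8bit.save_pretrained("flux.1-dev-transformer-8bit")
+```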
+
+
+
+
+Quantizing a model in 4-bit reduces your memory usage by 4x:
+
+```py
+from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+
+model_4bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quantization_config
+)
+```
+
+By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:
+
+```py
+import torch
+from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+
+model_4bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+ torch_dtype=torch.float32
+)
+model_4bit.transformer_blocks.layers[-1].norm2.weight.dtype
+```
+
+Call [`~ModelMixin.push_to_hub`] after loading it in 4-bit precision. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
+
+
+
+
+
+
+Training with 8-bit and 4-bit weights is only supported for training *extra* parameters.
+
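+For example, you can attach LoRA adapters (extra trainable parameters) on top of the frozen quantized weights with [PEFT](https://huggingface.co/docs/peft/index). A minimal sketch, reusing `model_4bit` from above; the target module names are an assumption based on Flux attention blocks:
+
+```py
+from peft import LoraConfig
+
+# Adjust target_modules for your model architecture.
+lora_config = LoraConfig(r=8, lora_alpha=8, target_modules=["to_k", "to_q", "to_v", "to_out.0"])
+model_4bit.add_adapter(lora_config)
+
+# Only the newly added LoRA parameters are trainable; the 4-bit base weights stay frozen.
+trainable = [name for name, param in model_4bit.named_parameters() if param.requires_grad]
+print(f"{len(trainable)} trainable parameter tensors")
+```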
+
+
+Check your memory footprint with the `get_memory_footprint` method:
+
+```py
+print(model_4bit.get_memory_footprint())
+```
+
+Quantized models can be loaded with the [`~ModelMixin.from_pretrained`] method without needing to specify the `quantization_config` parameters:
+
+```py
+from diffusers import FluxTransformer2DModel
+
+model_4bit = FluxTransformer2DModel.from_pretrained(
+ "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
+)
+```
+
+## 8-bit (LLM.int8() algorithm)
+
+
+
+Learn more about the details of 8-bit quantization in this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration)!
+
+
+
+This section explores some of the specific features of 8-bit models, such as outlier thresholds and skipping module conversion.
+
+### Outlier threshold
+
+An "outlier" is a hidden state value greater than a certain threshold, and these values are computed in fp16. While the values are usually normally distributed ([-3.5, 3.5]), this distribution can be very different for large models ([-60, 6] or [6, 60]). 8-bit quantization works well for values ~5, but beyond that, there is a significant performance penalty. A good default threshold value is 6, but a lower threshold may be needed for more unstable models (small models or finetuning).
+
+To find the best threshold for your model, we recommend experimenting with the `llm_int8_threshold` parameter in [`BitsAndBytesConfig`]:
+
+```py
+from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(
+ load_in_8bit=True, llm_int8_threshold=10,
+)
+
+model_8bit = FluxTransformer2DModel.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+)
+```
+
+### Skip module conversion
+
+For some models, you don't need to quantize every module to 8-bit; in fact, quantizing certain modules can cause instability. For example, for diffusion models like [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), the `proj_out` module can be skipped using the `llm_int8_skip_modules` parameter in [`BitsAndBytesConfig`]:
+
+```py
+from diffusers import SD3Transformer2DModel, BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(
+ load_in_8bit=True, llm_int8_skip_modules=["proj_out"],
+)
+
+model_8bit = SD3Transformer2DModel.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers",
+ subfolder="transformer",
+ quantization_config=quantization_config,
+)
+```
+
+
+## 4-bit (QLoRA algorithm)
+
+
+
+Learn more about its details in this [blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).
+
+
+
+This section explores some of the specific features of 4-bit models, such as changing the compute data type, using the Normal Float 4 (NF4) data type, and using nested quantization.
+
+
+### Compute data type
+
+To speed up computation, you can change the data type from float32 (the default value) to bf16 using the `bnb_4bit_compute_dtype` parameter in [`BitsAndBytesConfig`]:
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
+```
+
+### Normal Float 4 (NF4)
+
+NF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the `bnb_4bit_quant_type` parameter in the [`BitsAndBytesConfig`]:
+
+```py
+from diffusers import SD3Transformer2DModel, BitsAndBytesConfig
+
+nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+)
+
+model_nf4 = SD3Transformer2DModel.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers",
+ subfolder="transformer",
+ quantization_config=nf4_config,
+)
+```
+
+For inference, the `bnb_4bit_quant_type` does not have a huge impact on performance. However, to remain consistent with the model weights, you should use the same `bnb_4bit_compute_dtype` and `torch_dtype` values.
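+
+For example, a configuration that keeps the compute dtype and `torch_dtype` aligned might look like this (bfloat16 is just an illustrative choice):
+
+```py
+import torch
+from diffusers import SD3Transformer2DModel, BitsAndBytesConfig
+
+nf4_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16,
+)
+
+# torch_dtype matches bnb_4bit_compute_dtype so the non-quantized modules stay consistent.
+model_nf4 = SD3Transformer2DModel.from_pretrained(
+    "stabilityai/stable-diffusion-3-medium-diffusers",
+    subfolder="transformer",
+    quantization_config=nf4_config,
+    torch_dtype=torch.bfloat16,
+)
+```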
+
+### Nested quantization
+
+Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter.
+
+```py
+from diffusers import SD3Transformer2DModel, BitsAndBytesConfig
+
+double_quant_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_use_double_quant=True,
+)
+
+double_quant_model = SD3Transformer2DModel.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers",
+ subfolder="transformer",
+ quantization_config=double_quant_config,
+)
+```
+
+## Dequantizing `bitsandbytes` models
+
+Once quantized, you can dequantize the model to the original precision, but this might result in a small loss of quality. Make sure you have enough GPU RAM to fit the dequantized model.
+
+```python
+from diffusers import SD3Transformer2DModel, BitsAndBytesConfig
+
+double_quant_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_use_double_quant=True,
+)
+
+double_quant_model = SD3Transformer2DModel.from_pretrained(
+ "stabilityai/stable-diffusion-3-medium-diffusers",
+ subfolder="transformer",
+ quantization_config=double_quant_config,
+)
+double_quant_model.dequantize()
+```
+
+## Resources
+
+* [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)
+* [Training](https://gist.github.com/sayakpaul/05afd428bc089b47af7c016e42004527)
\ No newline at end of file
diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md
new file mode 100644
index 000000000000..d8adbc85a259
--- /dev/null
+++ b/docs/source/en/quantization/overview.md
@@ -0,0 +1,35 @@
+
+
+# Quantization
+
+Quantization techniques focus on representing data with less information while also trying to not lose too much accuracy. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size, which makes it easier to store and reduces memory usage. Lower precision can also speed up inference because it takes less time to perform calculations with fewer bits.
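+
+As a back-of-the-envelope sketch (the 12B parameter count is just an example, and real checkpoints carry some extra overhead such as quantization constants):
+
+```py
+# Approximate memory needed for the weights of a 12B-parameter model at different precisions.
+num_params = 12e9
+for name, bits in [("fp32", 32), ("fp16", 16), ("int8", 8), ("4-bit", 4)]:
+    print(f"{name}: ~{num_params * bits / 8 / 1e9:.0f} GB")
+```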
+
+
+
+Interested in adding a new quantization method to Diffusers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
+
+
+
+
+
+If you are new to the quantization field, we recommend checking out these beginner-friendly courses about quantization, created in collaboration with DeepLearning.AI:
+
+* [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/)
+* [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth/)
+
+
+
+## When to use what?
+
+This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
\ No newline at end of file
diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md
new file mode 100644
index 000000000000..8817431bede5
--- /dev/null
+++ b/examples/advanced_diffusion_training/README_flux.md
@@ -0,0 +1,351 @@
+# Advanced diffusion training examples
+
+## Train Dreambooth LoRA with Flux.1 Dev
+> [!TIP]
+> 💡 This example follows some of the techniques and recommended practices covered in the community-derived guide we made for SDXL training: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script).
+> As many of these are architecture-agnostic and generally relevant to fine-tuning of diffusion models, we suggest taking a look 🤗
+
+[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Flux and Stable Diffusion given just a few (3-5) images of a subject.
+
+LoRA - Low-Rank Adaptation of Large Language Models - was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
+In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
+- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)
+- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
+- LoRA attention layers allow controlling the extent to which the model is adapted toward new training images via a `scale` parameter.
+[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
+the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
+
+The `train_dreambooth_lora_flux_advanced.py` script shows how to implement DreamBooth LoRA, combining the training process shown in `train_dreambooth_lora_flux.py` with
+advanced features and techniques, inspired by and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),
+[ostris](https://x.com/ostrisai):[ai-toolkit](https://github.com/ostris/ai-toolkit), [bghira](https://github.com/bghira):[SimpleTuner](https://github.com/bghira/SimpleTuner), [Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️
+
+> [!NOTE]
+> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳
+> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)
+
+## Running locally with PyTorch
+
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the `examples/advanced_diffusion_training` folder and run
+```bash
+pip install -r requirements_flux.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell e.g. a notebook
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, if you set torch compile mode to True, there can be dramatic speedups.
+Note also that we use the PEFT library as the backend for LoRA training; make sure to have `peft>=0.6.0` installed in your environment.
+
+### Target Modules
+When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations with the prompts that describe them.
+More recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion transformer (DiT). With this change, we may also want to explore
+applying LoRA training to different types of layers and blocks. To allow more flexibility and control over the targeted modules, we added `--lora_layers` - in which you can specify, in a comma-separated string,
+the exact modules for LoRA training. Here are some examples of target modules you can provide:
+- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
+- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
+- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
+> [!NOTE]
+> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma-separated string:
+> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
+> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
+> [!NOTE]
+> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
+
+### Pivotal Tuning (and more)
+**Training with text encoder(s)**
+
+Alongside the Transformer, LoRA fine-tuning of the text encoders is also supported. In addition to the regular text encoder optimization
+available with `train_dreambooth_lora_flux_advanced.py`, the advanced script also supports **pivotal tuning**.
+[pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -
+we insert new tokens into the text encoders of the model, instead of reusing existing ones.
+We then optimize the newly-inserted token embeddings to represent the new concept.
+
+To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`).
+Please keep the following points in mind:
+
+* Flux uses two text encoders - [CLIP](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder) & [T5](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder_2) , by default `--train_text_encoder_ti` performs pivotal tuning for the **CLIP** encoder only.
+To activate pivotal tuning for both encoders, add the flag `--enable_t5_ti`.
+* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
+* **pure textual inversion** - to support the full range from pivotal tuning to textual inversion we introduce `--train_transformer_frac` which controls the fraction of epochs for which the transformer LoRA layers are trained. By default, `--train_transformer_frac=1`; to trigger a textual inversion run, set `--train_transformer_frac=0`. Values between 0 and 1 are supported as well, and we welcome the community to experiment with different settings and share the results!
+* **token initializer** - similar to the original textual inversion work, you can specify a concept of your choosing as the starting point for training. By default, when enabling `--train_text_encoder_ti`, the newly inserted tokens are initialized randomly. You can specify a token in `--initializer_concept` such that the starting point for the trained embeddings will be the embeddings associated with your chosen `--initializer_concept`.
+
+## Training examples
+
+Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.
+
+Let's first download it locally:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./3d_icon"
+snapshot_download(
+ "LinoyTsaban/3d_icon",
+ local_dir=local_dir, repo_type="dataset",
+ ignore_patterns=".gitattributes",
+)
+```
+
+Let's review some of the advanced features we're going to be using for this example:
+- **custom captions**:
+To use custom captioning, first ensure that you have the datasets library installed, otherwise you can install it by
+```bash
+pip install datasets
+```
+
+Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt")
+
+```
+--dataset_name=./3d_icon
+--caption_column=prompt
+```
+
+You can also load a dataset straight from the Hugging Face Hub by specifying its name in `dataset_name`.
+Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loading your own caption dataset.
+
+- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
+- **pivotal tuning**
+
+### Example #1: Pivotal tuning
+**Now, we can launch training:**
+
+```bash
+export MODEL_NAME="black-forest-labs/FLUX.1-dev"
+export DATASET_NAME="./3d_icon"
+export OUTPUT_DIR="3d-icon-Flux-LoRA"
+
+accelerate launch train_dreambooth_lora_flux_advanced.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME \
+ --instance_prompt="3d icon in the style of TOK" \
+ --output_dir=$OUTPUT_DIR \
+ --caption_column="prompt" \
+ --mixed_precision="bf16" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --repeats=1 \
+ --report_to="wandb"\
+ --gradient_accumulation_steps=1 \
+ --gradient_checkpointing \
+ --learning_rate=1.0 \
+ --text_encoder_lr=1.0 \
+ --optimizer="prodigy"\
+ --train_text_encoder_ti\
+ --train_text_encoder_ti_frac=0.5\
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --rank=8 \
+ --max_train_steps=700 \
+ --checkpointing_steps=2000 \
+ --seed="0" \
+ --push_to_hub
+```
+
+To better track our training experiments, we're using the following flags in the command above:
+
+* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
+* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
+
+Our experiments were conducted on a single 40GB A100 GPU.
+
+### Example #2: Pivotal tuning with T5
+Now let's try that with T5 as well, so instead of only optimizing the CLIP embeddings associated with newly inserted tokens, we'll optimize
+the T5 embeddings as well. We can do this by simply adding `--enable_t5_ti` to the previous configuration:
+```bash
+export MODEL_NAME="black-forest-labs/FLUX.1-dev"
+export DATASET_NAME="./3d_icon"
+export OUTPUT_DIR="3d-icon-Flux-LoRA"
+
+accelerate launch train_dreambooth_lora_flux_advanced.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME \
+ --instance_prompt="3d icon in the style of TOK" \
+ --output_dir=$OUTPUT_DIR \
+ --caption_column="prompt" \
+ --mixed_precision="bf16" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --repeats=1 \
+ --report_to="wandb"\
+ --gradient_accumulation_steps=1 \
+ --gradient_checkpointing \
+ --learning_rate=1.0 \
+ --text_encoder_lr=1.0 \
+ --optimizer="prodigy"\
+ --train_text_encoder_ti\
+ --enable_t5_ti\
+ --train_text_encoder_ti_frac=0.5\
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --rank=8 \
+ --max_train_steps=700 \
+ --checkpointing_steps=2000 \
+ --seed="0" \
+ --push_to_hub
+```
+
+### Example #3: Textual Inversion
+To explore pure textual inversion - i.e. only optimizing the text embeddings without training transformer LoRA layers - we
+can set the value for `--train_transformer_frac`, which controls the fraction of epochs in which the transformer is
+trained. By setting `--train_transformer_frac=0` and enabling `--train_text_encoder_ti`, we trigger a textual inversion training
+run.
+```bash
+export MODEL_NAME="black-forest-labs/FLUX.1-dev"
+export DATASET_NAME="./3d_icon"
+export OUTPUT_DIR="3d-icon-Flux-LoRA"
+
+accelerate launch train_dreambooth_lora_flux_advanced.py \
+ --pretrained_model_name_or_path=$MODEL_NAME \
+ --dataset_name=$DATASET_NAME \
+ --instance_prompt="3d icon in the style of TOK" \
+ --output_dir=$OUTPUT_DIR \
+ --caption_column="prompt" \
+ --mixed_precision="bf16" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --repeats=1 \
+ --report_to="wandb"\
+ --gradient_accumulation_steps=1 \
+ --gradient_checkpointing \
+ --learning_rate=1.0 \
+ --text_encoder_lr=1.0 \
+ --optimizer="prodigy"\
+ --train_text_encoder_ti\
+ --enable_t5_ti\
+ --train_text_encoder_ti_frac=0.5\
+ --train_transformer_frac=0\
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --rank=8 \
+ --max_train_steps=700 \
+ --checkpointing_steps=2000 \
+ --seed="0" \
+ --push_to_hub
+```
+### Inference - pivotal tuning
+
+Once training is done, we can perform inference like so:
+1. start by loading the transformer LoRA weights
+```python
+import torch
+from huggingface_hub import hf_hub_download, upload_file
+from diffusers import AutoPipelineForText2Image
+from safetensors.torch import load_file
+
+username = "linoyts"
+repo_id = f"{username}/3d-icon-Flux-LoRA"
+
+pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')
+
+
+pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
+```
+2. now we load the pivotal tuning embeddings
+> [!NOTE] #1 if `--enable_t5_ti` wasn't passed, we only load the embeddings to the CLIP encoder.
+
+> [!NOTE] #2 the number of tokens (i.e. `<s0>`, `<s1>`, ...) is either determined by `--num_new_tokens_per_abstraction` or by `--initializer_concept`. Make sure to update the inference code accordingly :)
+```python
+text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
+tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
+
+embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model")
+
+state_dict = load_file(embedding_path)
+# load embeddings of text_encoder 1 (CLIP ViT-L/14)
+pipe.load_textual_inversion(state_dict["clip_l"], token=["", ""], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti`
+pipe.load_textual_inversion(state_dict["t5"], token=["", ""], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+```
+
+3. let's generate images
+
+```python
+instance_token = ""
+prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
+
+image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
+image.save("llama.png")
+```
+
+### Inference - pure textual inversion
+In this case, we don't load transformer layers as before, since we only optimize the text embeddings. The output of a textual inversion training run is a
+`.safetensors` file containing the trained embeddings for the new tokens, either for the CLIP encoder only or for both encoders (CLIP and T5).
+
+1. start by loading the embeddings.
+💡 note that here too, if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder.
+
+```python
+import torch
+from huggingface_hub import hf_hub_download, upload_file
+from diffusers import AutoPipelineForText2Image
+from safetensors.torch import load_file
+
+username = "linoyts"
+repo_id = f"{username}/3d-icon-Flux-LoRA"
+
+pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')
+
+text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
+tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
+
+embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model")
+
+state_dict = load_file(embedding_path)
+# load embeddings of text_encoder 1 (CLIP ViT-L/14)
+pipe.load_textual_inversion(state_dict["clip_l"], token=["", ""], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
+# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti`
+pipe.load_textual_inversion(state_dict["t5"], token=["", ""], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
+```
+2. let's generate images
+
+```python
+instance_token = ""
+prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
+
+image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
+image.save("llama.png")
+```
+
+### Comfy UI / AUTOMATIC1111 Inference
+The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!
+
+**AUTOMATIC1111 / SD.Next** \
+In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
+- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
+- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.
+
+You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.
+
+**ComfyUI** \
+In ComfyUI we will load a LoRA and a textual embedding at the same time.
+- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
+- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).
diff --git a/examples/advanced_diffusion_training/requirements_flux.txt b/examples/advanced_diffusion_training/requirements_flux.txt
new file mode 100644
index 000000000000..dbc124ff6526
--- /dev/null
+++ b/examples/advanced_diffusion_training/requirements_flux.txt
@@ -0,0 +1,8 @@
+accelerate>=0.31.0
+torchvision
+transformers>=4.41.2
+ftfy
+tensorboard
+Jinja2
+peft>=0.11.1
+sentencepiece
\ No newline at end of file
diff --git a/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py
new file mode 100644
index 000000000000..e29c99821303
--- /dev/null
+++ b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py
@@ -0,0 +1,283 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate):
+ instance_data_dir = "docs/source/en/imgs"
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
+ script_path = "examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py"
+
+ def test_dreambooth_lora_flux(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_text_encoder_flux(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --train_text_encoder
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ starts_with_expected_prefix = all(
+ (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
+ )
+ self.assertTrue(starts_with_expected_prefix)
+
+ def test_dreambooth_lora_pivotal_tuning_flux_clip(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --train_text_encoder_ti
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+ # make sure embeddings were also saved
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # make sure the state_dict has the correct naming in the parameters.
+ textual_inversion_state_dict = safetensors.torch.load_file(
+ os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")
+ )
+ is_clip = all("clip_l" in k for k in textual_inversion_state_dict.keys())
+ self.assertTrue(is_clip)
+
+ # when performing pivotal tuning, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --train_text_encoder_ti
+ --enable_t5_ti
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+ # make sure embeddings were also saved
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # make sure the state_dict has the correct naming in the parameters.
+ textual_inversion_state_dict = safetensors.torch.load_file(
+ os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors")
+ )
+ is_te = all(("clip_l" in k or "t5" in k) for k in textual_inversion_state_dict.keys())
+ self.assertTrue(is_te)
+
+ # when performing pivotal tuning, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_latent_caching(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --cache_latents
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
+
+ resume_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
new file mode 100644
index 000000000000..e3e46ead8ee3
--- /dev/null
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -0,0 +1,2463 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import argparse
+import copy
+import itertools
+import logging
+import math
+import os
+import random
+import re
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from safetensors.torch import save_file
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ FluxPipeline,
+ FluxTransformer2DModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ _set_state_dict_into_text_encoder,
+ cast_training_params,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+ free_memory,
+)
+from diffusers.utils import (
+ check_min_version,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.32.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ train_text_encoder=False,
+ train_text_encoder_ti=False,
+ enable_t5_ti=False,
+ pure_textual_inversion=False,
+ token_abstraction_dict=None,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+):
+ widget_dict = []
+ trigger_str = f"You should use {instance_prompt} to trigger the image generation."
+
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+ diffusers_load_lora = ""
+ diffusers_imports_pivotal = ""
+ diffusers_example_pivotal = ""
+ if not pure_textual_inversion:
+ diffusers_load_lora = (
+ f"""pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')"""
+ )
+ if train_text_encoder_ti:
+ embeddings_filename = f"{repo_folder}_emb"
+ ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
+ trigger_str = (
+ "To trigger image generation of trained concept(or concepts) replace each concept identifier "
+ "in you prompt with the new inserted tokens:\n"
+ )
+ diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
+ from safetensors.torch import load_file
+ """
+ if enable_t5_ti:
+ diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
+ state_dict = load_file(embedding_path)
+ pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
+ pipeline.load_textual_inversion(state_dict["t5"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
+ """
+ else:
+ diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
+ state_dict = load_file(embedding_path)
+ pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
+ """
+ if token_abstraction_dict:
+ for key, value in token_abstraction_dict.items():
+ tokens = "".join(value)
+ trigger_str += f"""
+ to trigger concept `{key}` → use `{tokens}` in your prompt \n
+ """
+
+ model_description = f"""
+# Flux DreamBooth LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth LoRA weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md).
+
+Was LoRA for the text encoder enabled? {train_text_encoder}.
+
+Pivotal tuning was enabled: {train_text_encoder_ti}.
+
+## Trigger words
+
+{trigger_str}
+
+## Download model
+
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+{diffusers_imports_pivotal}
+pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda')
+{diffusers_load_lora}
+{diffusers_example_pivotal}
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "flux",
+ "flux-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def load_text_encoders(class_one, class_two):
+ text_encoder_one = class_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = class_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ return text_encoder_one, text_encoder_two
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ torch_dtype,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ autocast_ctx = nullcontext()
+
+ with autocast_ctx:
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ free_memory()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "T5EncoderModel":
+ from transformers import T5EncoderModel
+
+ return T5EncoderModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--token_abstraction",
+ type=str,
+ default="TOK",
+ help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, "
+ "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. "
+ "'TOK,TOK2,TOK3' etc.",
+ )
+
+ parser.add_argument(
+ "--num_new_tokens_per_abstraction",
+ type=int,
+ default=None,
+ help="number of new tokens inserted to the tokenizers per token_abstraction identifier when "
+ "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
+ "tokens - ",
+ )
+ parser.add_argument(
+ "--initializer_concept",
+ type=str,
+ default=None,
+ help="the concept to use to initialize the new inserted tokens when training with "
+ "--train_text_encoder_ti = True. By default, new tokens () are initialized with random value. "
+ "Alternatively, you could specify a different word/words whos value will be used as the starting point for the new inserted tokens. "
+ "--num_new_tokens_per_abstraction is ignored when initializer_concept is provided",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=512,
+ help="Maximum sequence length to use with with the T5 text encoder",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="flux-dreambooth-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_text_encoder_ti",
+ action="store_true",
+ help=("Whether to use pivotal tuning / textual inversion"),
+ )
+ parser.add_argument(
+ "--enable_t5_ti",
+ action="store_true",
+ help=(
+ "Whether to use pivotal tuning / textual inversion for the T5 encoder as well (in addition to CLIP encoder)"
+ ),
+ )
+
+ parser.add_argument(
+ "--train_text_encoder_ti_frac",
+ type=float,
+ default=0.5,
+ help=("The percentage of epochs to perform textual inversion"),
+ )
+
+ parser.add_argument(
+ "--train_text_encoder_frac",
+ type=float,
+ default=1.0,
+ help=("The percentage of epochs to perform text encoder tuning"),
+ )
+ parser.add_argument(
+ "--train_transformer_frac",
+ type=float,
+ default=1.0,
+ help=("The percentage of epochs to perform transformer tuning"),
+ )
+
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=3.5,
+ help="the FLUX.1 dev variant is a guidance distilled model",
+ )
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument(
+ "--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params"
+ )
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+ parser.add_argument(
+ "--lora_layers",
+ type=str,
+ default=None,
+ help=(
+ "The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. "
+ 'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
+ ),
+ )
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ if args.train_text_encoder and args.train_text_encoder_ti:
+ raise ValueError(
+ "Specify only one of `--train_text_encoder` or `--train_text_encoder_ti. "
+ "For full LoRA text encoder training check --train_text_encoder, for textual "
+ "inversion training check `--train_text_encoder_ti`"
+ )
+ if args.train_transformer_frac < 1 and not args.train_text_encoder_ti:
+ raise ValueError(
+ "--train_transformer_frac must be == 1 if text_encoder training / textual inversion is not enabled."
+ )
+ if args.train_transformer_frac < 1 and args.train_text_encoder_ti_frac < 1:
+ raise ValueError(
+ "--train_transformer_frac and --train_text_encoder_ti_frac are identical and smaller than 1. "
+ "This contradicts with --max_train_steps, please specify different values or set both to 1."
+ )
+ if args.enable_t5_ti and not args.train_text_encoder_ti:
+ logger.warning("You need not use --enable_t5_ti without --train_text_encoder_ti.")
+
+ if args.train_text_encoder_ti and args.initializer_concept and args.num_new_tokens_per_abstraction:
+ logger.warning(
+ "When specifying --initializer_concept, the number of tokens per abstraction is detrimned "
+ "by the initializer token. --num_new_tokens_per_abstraction will be ignored"
+ )
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ if args.class_data_dir is not None:
+ logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ logger.warning("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+# Modified from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py
+class TokenEmbeddingsHandler:
+ def __init__(self, text_encoders, tokenizers):
+ self.text_encoders = text_encoders
+ self.tokenizers = tokenizers
+
+ self.train_ids: Optional[torch.Tensor] = None
+ self.train_ids_t5: Optional[torch.Tensor] = None
+ self.inserting_toks: Optional[List[str]] = None
+ self.embeddings_settings = {}
+
+ def initialize_new_tokens(self, inserting_toks: List[str]):
+ idx = 0
+ for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
+ assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
+ assert all(
+ isinstance(tok, str) for tok in inserting_toks
+ ), "All elements in inserting_toks should be strings."
+
+ self.inserting_toks = inserting_toks
+ special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
+ tokenizer.add_special_tokens(special_tokens_dict)
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
+ text_encoder.resize_token_embeddings(len(tokenizer))
+
+ # Convert the token abstractions to ids
+ if idx == 0:
+ self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+ else:
+ self.train_ids_t5 = tokenizer.convert_tokens_to_ids(self.inserting_toks)
+
+ # random initialization of new tokens
+ embeds = (
+ text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
+ )
+ std_token_embedding = embeds.weight.data.std()
+
+ logger.info(f"{idx} text encoder's std_token_embedding: {std_token_embedding}")
+
+ train_ids = self.train_ids if idx == 0 else self.train_ids_t5
+ # if initializer_concept are not provided, token embeddings are initialized randomly
+ if args.initializer_concept is None:
+ hidden_size = (
+ text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size
+ )
+ embeds.weight.data[train_ids] = (
+ torch.randn(len(train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype)
+ * std_token_embedding
+ )
+ else:
+ # Convert the initializer_token, placeholder_token to ids
+ initializer_token_ids = tokenizer.encode(args.initializer_concept, add_special_tokens=False)
+ for token_idx, token_id in enumerate(train_ids):
+ embeds.weight.data[token_id] = (embeds.weight.data)[
+ initializer_token_ids[token_idx % len(initializer_token_ids)]
+ ].clone()
+
+ self.embeddings_settings[f"original_embeddings_{idx}"] = embeds.weight.data.clone()
+ self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding
+
+ # makes sure we don't update any embedding weights besides the newly added token
+ index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
+ index_no_updates[train_ids] = False
+
+ self.embeddings_settings[f"index_no_updates_{idx}"] = index_no_updates
+
+ logger.info(self.embeddings_settings[f"index_no_updates_{idx}"].shape)
+
+ idx += 1
+
+ def save_embeddings(self, file_path: str):
+ assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
+ tensors = {}
+ # text_encoder_one, idx==0 - CLIP ViT-L/14, text_encoder_two, idx==1 - T5 xxl
+ idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
+ for idx, text_encoder in enumerate(self.text_encoders):
+ train_ids = self.train_ids if idx == 0 else self.train_ids_t5
+ embeds = (
+ text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
+ )
+ assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
+ new_token_embeddings = embeds.weight.data[train_ids]
+
+ # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0) and "t5" (for text_encoder 1).
+ # Note: When loading with diffusers, any name can work - simply specify in inference
+ tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings
+ # tensors[f"text_encoders_{idx}"] = new_token_embeddings
+
+ save_file(tensors, file_path)
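+ # the saved embeddings can later be loaded at inference time with, e.g.,
+ # pipeline.load_textual_inversion(state_dict["clip_l"], token=[...], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)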
+
+ @property
+ def dtype(self):
+ return self.text_encoders[0].dtype
+
+ @property
+ def device(self):
+ return self.text_encoders[0].device
+
+ @torch.no_grad()
+ def retract_embeddings(self):
+ for idx, text_encoder in enumerate(self.text_encoders):
+ embeds = (
+ text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
+ )
+ index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
+ embeds.weight.data[index_no_updates] = (
+ self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
+ .to(device=text_encoder.device)
+ .to(dtype=text_encoder.dtype)
+ )
+
+ # for the parts that were updated, we need to normalize them
+ # to have the same std as before
+ std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"]
+
+ index_updates = ~index_no_updates
+ new_embeddings = embeds.weight.data[index_updates]
+ off_ratio = std_token_embedding / new_embeddings.std()
+
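+ # raising off_ratio to the 0.1 power only gently nudges the std of the trained embeddings back
+ # toward the original std instead of fully rescaling them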
+ new_embeddings = new_embeddings * (off_ratio**0.1)
+ embeds.weight.data[index_updates] = new_embeddings
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ train_text_encoder_ti,
+ token_abstraction_dict=None, # token mapping for textual inversion
+ class_data_root=None,
+ class_num=None,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+ self.token_abstraction_dict = token_abstraction_dict
+ self.train_text_encoder_ti = train_text_encoder_ti
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+
+ self.pixel_values = []
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
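+ # p=1.0 here because whether to flip is decided manually below (only when args.random_flip and a coin toss succeeds)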
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ for image in self.instance_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ image = train_transforms(image)
+ self.pixel_values.append(image)
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
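+ # iterate over the longer of the two image sets; __getitem__ wraps around via modulo indexing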
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.pixel_values[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ if self.train_text_encoder_ti:
+ # replace instances of --token_abstraction in caption with the new tokens: "<si><si+1>" etc.
+ for token_abs, token_replacement in self.token_abstraction_dict.items():
+ caption = caption.replace(token_abs, "".join(token_replacement))
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else: # the given instance prompt is used for all images
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ return batch
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=False):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ add_special_tokens=add_special_tokens,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+def _get_t5_prompt_embeds(
+ text_encoder,
+ tokenizer,
+ max_sequence_length=512,
+ prompt=None,
+ num_images_per_prompt=1,
+ device=None,
+ text_input_ids=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+
+ dtype = text_encoder.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+
+def _get_clip_prompt_embeds(
+ text_encoder,
+ tokenizer,
+ prompt: str,
+ device=None,
+ text_input_ids=None,
+ num_images_per_prompt: int = 1,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+
+def encode_prompt(
+ text_encoders,
+ tokenizers,
+ prompt: str,
+ max_sequence_length,
+ device=None,
+ num_images_per_prompt: int = 1,
+ text_input_ids_list=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+ dtype = text_encoders[0].dtype
+
+ pooled_prompt_embeds = _get_clip_prompt_embeds(
+ text_encoder=text_encoders[0],
+ tokenizer=tokenizers[0],
+ prompt=prompt,
+ device=device if device is not None else text_encoders[0].device,
+ num_images_per_prompt=num_images_per_prompt,
+ text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None,
+ )
+
+ prompt_embeds = _get_t5_prompt_embeds(
+ text_encoder=text_encoders[1],
+ tokenizer=tokenizers[1],
+ max_sequence_length=max_sequence_length,
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device if device is not None else text_encoders[1].device,
+ text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None,
+ )
+
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
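+ # Flux expects per-token text position ids alongside the prompt embeddings; zeros are used here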
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+
+# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer:
+# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95
+class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ with torch.no_grad():
+ # create weights for timesteps
+ num_timesteps = 1000
+
+ # generate the multiplier based on cosmap loss weighing
+ # this is only used on linear timesteps for now
+
+ # cosine map weighing is higher in the middle and lower at the ends
+ # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
+ # cosmap_weighing = 2 / (math.pi * bot)
+
+ # sigma sqrt weighing is significantly higher at the end and lower at the beginning
+ sigma_sqrt_weighing = (self.sigmas**-2.0).float()
+ # clip at 1e4 (1e6 is too high)
+ sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
+ # bring to a mean of 1
+ sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()
+
+ # Create linear timesteps from 1000 to 0
+ timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu")
+
+ self.linear_timesteps = timesteps
+ # self.linear_timesteps_weights = cosmap_weighing
+ self.linear_timesteps_weights = sigma_sqrt_weighing
+
+ # self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
+ pass
+
+ def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
+ # Get the indices of the timesteps
+ step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
+
+ # Get the weights for the timesteps
+ weights = self.linear_timesteps_weights[step_indices].flatten()
+
+ return weights
+
+ def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
+ sigmas = self.sigmas.to(device=device, dtype=dtype)
+ schedule_timesteps = self.timesteps.to(device)
+ timesteps = timesteps.to(device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
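+ # append singleton dims until sigma broadcasts against an n_dim-shaped model input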
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+
+ return sigma
+
+ def add_noise(
+ self,
+ original_samples: torch.Tensor,
+ noise: torch.Tensor,
+ timesteps: torch.Tensor,
+ ) -> torch.Tensor:
+ ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
+ ## Add noise according to flow matching.
+ ## zt = (1 - texp) * x + texp * z1
+
+ # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
+ # timestep needs to be in [0, 1], we store them in [0, 1000]
+ # noisy_sample = (1 - timestep) * latent + timestep * noise
+ t_01 = (timesteps / 1000).to(original_samples.device)
+ noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
+
+ # n_dim = original_samples.ndim
+ # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
+ # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
+ return noisy_model_input
+
+ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+ return sample
+
+ def set_train_timesteps(self, num_timesteps, device, linear=False):
+ if linear:
+ timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
+ self.timesteps = timesteps
+ return timesteps
+ else:
+ # distribute them closer to the center. Inference distributes them with a bias toward the first timesteps
+ # Generate values from 0 to 1
+ t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
+
+ # Scale and reverse the values to go from 1000 to 0
+ timesteps = (1 - t) * 1000
+
+ # Sort the timesteps in descending order
+ timesteps, _ = torch.sort(timesteps, descending=True)
+
+ self.timesteps = timesteps.to(device=device)
+
+ return timesteps
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+ torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = FluxPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ free_memory()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ model_id = args.hub_model_id or Path(args.output_dir).name
+ repo_id = None
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=model_id,
+ exist_ok=True,
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+ tokenizer_two = T5TokenizerFast.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
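+ # keep an untouched copy of the scheduler; its sigmas/timesteps are used for loss weighting during training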
+ text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ transformer = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+ )
+
+ if args.train_text_encoder_ti:
+ # we parse the provided token identifier (or identifiers) into a list, e.g. "TOK" -> ["TOK"], "TOK,
+ # TOK2" -> ["TOK", "TOK2"] etc.
+ token_abstraction_list = [place_holder.strip() for place_holder in re.split(r",\s*", args.token_abstraction)]
+ logger.info(f"list of token identifiers: {token_abstraction_list}")
+
+ if args.initializer_concept is None:
+ num_new_tokens_per_abstraction = (
+ 2 if args.num_new_tokens_per_abstraction is None else args.num_new_tokens_per_abstraction
+ )
+ # if args.initializer_concept is provided, we ignore args.num_new_tokens_per_abstraction
+ else:
+ token_ids = tokenizer_one.encode(args.initializer_concept, add_special_tokens=False)
+ num_new_tokens_per_abstraction = len(token_ids)
+ if args.enable_t5_ti:
+ token_ids_t5 = tokenizer_two.encode(args.initializer_concept, add_special_tokens=False)
+ num_new_tokens_per_abstraction = max(len(token_ids), len(token_ids_t5))
+ logger.info(
+ f"initializer_concept: {args.initializer_concept}, num_new_tokens_per_abstraction: {num_new_tokens_per_abstraction}"
+ )
+
+ token_abstraction_dict = {}
+ token_idx = 0
+ for i, token in enumerate(token_abstraction_list):
+ token_abstraction_dict[token] = [f"<s{token_idx + i + j}>" for j in range(num_new_tokens_per_abstraction)]
+ token_idx += num_new_tokens_per_abstraction - 1
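+ # e.g. with two identifiers and 2 new tokens each: {"TOK": ["<s0>", "<s1>"], "TOK2": ["<s2>", "<s3>"]}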
+
+ # replace instances of --token_abstraction in --instance_prompt with the new tokens: "<si><si+1>" etc.
+ for token_abs, token_replacement in token_abstraction_dict.items():
+ new_instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement))
+ if args.instance_prompt == new_instance_prompt:
+ logger.warning(
+ "Note! the instance prompt provided in --instance_prompt does not include the token abstraction specified "
+ "--token_abstraction. This may lead to incorrect optimization of text embeddings during pivotal tuning"
+ )
+ args.instance_prompt = new_instance_prompt
+ if args.with_prior_preservation:
+ args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement))
+ if args.validation_prompt:
+ args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
+
+ # initialize the new tokens for textual inversion
+ text_encoders = [text_encoder_one, text_encoder_two] if args.enable_t5_ti else [text_encoder_one]
+ tokenizers = [tokenizer_one, tokenizer_two] if args.enable_t5_ti else [tokenizer_one]
+ embedding_handler = TokenEmbeddingsHandler(text_encoders, tokenizers)
+ inserting_toks = []
+ for new_tok in token_abstraction_dict.values():
+ inserting_toks.extend(new_tok)
+ embedding_handler.initialize_new_tokens(inserting_toks=inserting_toks)
+
+ # We only train the additional adapter LoRA layers
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ vae.to(accelerator.device, dtype=weight_dtype)
+ transformer.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+
+ if args.lora_layers is not None:
+ target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+ else:
+ target_modules = [
+ "attn.to_k",
+ "attn.to_q",
+ "attn.to_v",
+ "attn.to_out.0",
+ "attn.add_k_proj",
+ "attn.add_q_proj",
+ "attn.add_v_proj",
+ "attn.to_add_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "ff_context.net.0.proj",
+ "ff_context.net.2",
+ ]
+ # now we will add new LoRA weights to the attention layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=target_modules,
+ )
+ transformer.add_adapter(transformer_lora_config)
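+ # with add_adapter, only the injected LoRA parameters are trainable; the frozen base weights stay in weight_dtype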
+
+ if args.train_text_encoder:
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder_one.add_adapter(text_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+ text_encoder_one_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers
+ text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ FluxPipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ )
+ if args.train_text_encoder_ti:
+ embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors")
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+ text_encoder_one_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict = FluxPipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+ if args.train_text_encoder:
+ # Do we need to call `scale_lora_layers()` here?
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one_])
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one])
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+
+ if args.train_text_encoder:
+ text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+ # if we use textual inversion, we freeze all parameters except for the token embeddings
+ # in text encoder
+ elif args.train_text_encoder_ti:
+ text_lora_parameters_one = [] # CLIP
+ for name, param in text_encoder_one.named_parameters():
+ if "token_embedding" in name:
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+ param.data = param.to(dtype=torch.float32)
+ param.requires_grad = True
+ text_lora_parameters_one.append(param)
+ else:
+ param.requires_grad = False
+ if args.enable_t5_ti: # whether to do pivotal tuning/textual inversion for T5 as well
+ text_lora_parameters_two = []
+ for name, param in text_encoder_two.named_parameters():
+ if "token_embedding" in name:
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+ param.data = param.to(dtype=torch.float32)
+ param.requires_grad = True
+ text_lora_parameters_two.append(param)
+ else:
+ param.requires_grad = False
+
+ # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
+ freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
+
+ # if --train_text_encoder_ti and train_transformer_frac == 0 we're essentially performing textual inversion
+ # and not training transformer LoRA layers
+ pure_textual_inversion = args.train_text_encoder_ti and args.train_transformer_frac == 0
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ if not freeze_text_encoder:
+ # different learning rate for text encoder and transformer
+ text_parameters_one_with_lr = {
+ "params": text_lora_parameters_one,
+ "weight_decay": args.adam_weight_decay_text_encoder
+ if args.adam_weight_decay_text_encoder
+ else args.adam_weight_decay,
+ "lr": args.text_encoder_lr,
+ }
+ if not args.enable_t5_ti:
+ # pure textual inversion - only clip
+ if pure_textual_inversion:
+ params_to_optimize = [
+ text_parameters_one_with_lr,
+ ]
+ te_idx = 0
+ else: # regular te training or regular pivotal for clip
+ params_to_optimize = [
+ transformer_parameters_with_lr,
+ text_parameters_one_with_lr,
+ ]
+ te_idx = 1
+ elif args.enable_t5_ti:
+ # pivotal tuning of clip & t5
+ text_parameters_two_with_lr = {
+ "params": text_lora_parameters_two,
+ "weight_decay": args.adam_weight_decay_text_encoder
+ if args.adam_weight_decay_text_encoder
+ else args.adam_weight_decay,
+ "lr": args.text_encoder_lr,
+ }
+ # pure textual inversion - only clip & t5
+ if pure_textual_inversion:
+ params_to_optimize = [text_parameters_one_with_lr, text_parameters_two_with_lr]
+ te_idx = 0
+ else: # regular pivotal tuning of clip & t5
+ params_to_optimize = [
+ transformer_parameters_with_lr,
+ text_parameters_one_with_lr,
+ text_parameters_two_with_lr,
+ ]
+ te_idx = 1
+ else:
+ params_to_optimize = [
+ transformer_parameters_with_lr,
+ ]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. "
+            "Defaulting to adamW."
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+                "Learning rate is too low. When using prodigy, it's generally better to set the learning rate to around 1.0"
+ )
+ if not freeze_text_encoder and args.text_encoder_lr:
+ logger.warning(
+                f"Learning rates were provided both for the transformer and the text encoder - e.g. text_encoder_lr:"
+                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+                f"When using prodigy, only learning_rate is used as the initial learning rate."
+ )
+            # change the learning rate of the text encoder parameter groups to --learning_rate,
+            # since with prodigy only learning_rate is used as the initial learning rate
+
+ params_to_optimize[te_idx]["lr"] = args.learning_rate
+ params_to_optimize[-1]["lr"] = args.learning_rate
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ train_text_encoder_ti=args.train_text_encoder_ti,
+ token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ if freeze_text_encoder:
+ tokenizers = [tokenizer_one, tokenizer_two]
+ text_encoders = [text_encoder_one, text_encoder_two]
+
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+ text_encoders, tokenizers, prompt, args.max_sequence_length
+ )
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ text_ids = text_ids.to(accelerator.device)
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if freeze_text_encoder and not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings(
+ args.instance_prompt, text_encoders, tokenizers
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ if freeze_text_encoder:
+ class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings(
+ args.class_prompt, text_encoders, tokenizers
+ )
+
+ # Clear the memory here
+ if freeze_text_encoder and not train_dataset.custom_instance_prompts:
+ del tokenizers, text_encoders, text_encoder_one, text_encoder_two
+ free_memory()
+
+ # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion
+ add_special_tokens_clip = True if args.train_text_encoder_ti else False
+ add_special_tokens_t5 = True if (args.train_text_encoder_ti and args.enable_t5_ti) else False
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+
+ if not train_dataset.custom_instance_prompts:
+ if freeze_text_encoder:
+ prompt_embeds = instance_prompt_hidden_states
+ pooled_prompt_embeds = instance_pooled_prompt_embeds
+ text_ids = instance_text_ids
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
+ text_ids = torch.cat([text_ids, class_text_ids], dim=0)
+        # if we're optimizing the text encoder (whether a single instance prompt is used for all images or custom prompts are provided),
+        # we need to tokenize and encode the batch prompts on every training step
+ else:
+ tokens_one = tokenize_prompt(
+ tokenizer_one, args.instance_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens_clip
+ )
+ tokens_two = tokenize_prompt(
+ tokenizer_two,
+ args.instance_prompt,
+ max_sequence_length=args.max_sequence_length,
+ add_special_tokens=add_special_tokens_t5,
+ )
+ if args.with_prior_preservation:
+ class_tokens_one = tokenize_prompt(
+ tokenizer_one,
+ args.class_prompt,
+ max_sequence_length=77,
+ add_special_tokens=add_special_tokens_clip,
+ )
+ class_tokens_two = tokenize_prompt(
+ tokenizer_two,
+ args.class_prompt,
+ max_sequence_length=args.max_sequence_length,
+ add_special_tokens=add_special_tokens_t5,
+ )
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+
+ vae_config_shift_factor = vae.config.shift_factor
+ vae_config_scaling_factor = vae.config.scaling_factor
+ vae_config_block_out_channels = vae.config.block_out_channels
+ if args.cache_latents:
+ latents_cache = []
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
+ with torch.no_grad():
+ batch["pixel_values"] = batch["pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=weight_dtype
+ )
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+
+ if args.validation_prompt is None:
+ del vae
+ free_memory()
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if not freeze_text_encoder:
+ if args.enable_t5_ti:
+ (
+ transformer,
+ text_encoder_one,
+ text_encoder_two,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ ) = accelerator.prepare(
+ transformer,
+ text_encoder_one,
+ text_encoder_two,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ )
+ else:
+ transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
+ )
+
+ else:
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-flux-dev-lora-advanced"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+            # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.")
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ logger.info(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
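+        # find the index of each sampled timestep in the scheduler's timestep schedule
+        # so that the matching sigma can be gathered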
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ if args.train_text_encoder:
+ num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs)
+ num_train_epochs_transformer = int(args.train_transformer_frac * args.num_train_epochs)
+    elif args.train_text_encoder_ti:
+ num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs)
+ num_train_epochs_transformer = int(args.train_transformer_frac * args.num_train_epochs)
+
+    # flags used for pivotal tuning: once the text encoder / transformer fraction of epochs has elapsed,
+    # the corresponding learning rate is zeroed out ("pivoting")
+ pivoted_te = False
+ pivoted_tr = False
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+ # if performing any kind of optimization of text_encoder params
+ if args.train_text_encoder or args.train_text_encoder_ti:
+ if epoch == num_train_epochs_text_encoder:
+ # flag to stop text encoder optimization
+ logger.info(f"PIVOT TE {epoch}")
+ pivoted_te = True
+ else:
+ # still optimizing the text encoder
+ if args.train_text_encoder:
+ text_encoder_one.train()
+                    # set top parameter requires_grad = True so that gradient checkpointing works
+ accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+ elif args.train_text_encoder_ti: # textual inversion / pivotal tuning
+ text_encoder_one.train()
+ if args.enable_t5_ti:
+ text_encoder_two.train()
+
+ if epoch == num_train_epochs_transformer:
+ # flag to stop transformer optimization
+ logger.info(f"PIVOT TRANSFORMER {epoch}")
+ pivoted_tr = True
+
+ for step, batch in enumerate(train_dataloader):
+ if pivoted_te:
+ # stopping optimization of text_encoder params
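+                # the text encoder groups sit at te_idx and at the end of param_groups
+                # (the last group is the T5 group when --enable_t5_ti is set; otherwise it coincides with the CLIP group)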
+ optimizer.param_groups[te_idx]["lr"] = 0.0
+ optimizer.param_groups[-1]["lr"] = 0.0
+ elif pivoted_tr and not pure_textual_inversion:
+ logger.info(f"PIVOT TRANSFORMER {epoch}")
+ optimizer.param_groups[0]["lr"] = 0.0
+
+ with accelerator.accumulate(transformer):
+ prompts = batch["prompts"]
+
+                # encode the batch prompts when custom prompts are provided for each image
+ if train_dataset.custom_instance_prompts:
+ if freeze_text_encoder:
+ prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
+ prompts, text_encoders, tokenizers
+ )
+ else:
+ tokens_one = tokenize_prompt(
+ tokenizer_one, prompts, max_sequence_length=77, add_special_tokens=add_special_tokens_clip
+ )
+ tokens_two = tokenize_prompt(
+ tokenizer_two,
+ prompts,
+ max_sequence_length=args.max_sequence_length,
+ add_special_tokens=add_special_tokens_t5,
+ )
+
+ if not freeze_text_encoder:
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=[None, None],
+ text_input_ids_list=[tokens_one, tokens_two],
+ max_sequence_length=args.max_sequence_length,
+ device=accelerator.device,
+ prompt=prompts,
+ )
+
+ # Convert images to latent space
+ if args.cache_latents:
+ model_input = latents_cache[step].sample()
+ else:
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ model_input = vae.encode(pixel_values).latent_dist.sample()
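+                # normalize the latents with the VAE's shift and scaling factors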
+ model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
+ model_input = model_input.to(dtype=weight_dtype)
+
+ vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
+
+ latent_image_ids = FluxPipeline._prepare_latent_image_ids(
+ model_input.shape[0],
+ model_input.shape[2],
+ model_input.shape[3],
+ accelerator.device,
+ weight_dtype,
+ )
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
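+                # map the sampled density u in [0, 1) to discrete scheduler timestep indices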
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+
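+                # pack the latent feature map into the patchified token sequence (2x2 patches) expected by the Flux transformer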
+ packed_noisy_model_input = FluxPipeline._pack_latents(
+ noisy_model_input,
+ batch_size=model_input.shape[0],
+ num_channels_latents=model_input.shape[1],
+ height=model_input.shape[2],
+ width=model_input.shape[3],
+ )
+
+ # handle guidance
+ if transformer.config.guidance_embeds:
+ guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
+ guidance = guidance.expand(model_input.shape[0])
+ else:
+ guidance = None
+
+ # Predict the noise residual
+ model_pred = transformer(
+ hidden_states=packed_noisy_model_input,
+                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+ timestep=timesteps / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ return_dict=False,
+ )[0]
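+                # unpack the predicted token sequence back into the latent-image layout before computing the loss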
+ model_pred = FluxPipeline._unpack_latents(
+ model_pred,
+ height=int(model_input.shape[2] * vae_scale_factor / 2),
+ width=int(model_input.shape[3] * vae_scale_factor / 2),
+ vae_scale_factor=vae_scale_factor,
+ )
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+                # flow matching loss: the target is the flow-matching velocity (noise - data)
+ target = noise - model_input
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = torch.mean(
+ (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+ target_prior.shape[0], -1
+ ),
+ 1,
+ )
+ prior_loss = prior_loss.mean()
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ if not freeze_text_encoder:
+ if args.train_text_encoder:
+ params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters())
+ elif pure_textual_inversion:
+ params_to_clip = itertools.chain(
+ text_encoder_one.parameters(), text_encoder_two.parameters()
+ )
+ else:
+ params_to_clip = itertools.chain(
+ transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
+ )
+ else:
+ params_to_clip = itertools.chain(transformer.parameters())
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # every step, we reset the embeddings to the original embeddings.
+ if args.train_text_encoder_ti:
+ embedding_handler.retract_embeddings()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ if freeze_text_encoder:
+ text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+ pipeline = FluxPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
+ transformer=accelerator.unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline_args = {"prompt": args.validation_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ torch_dtype=weight_dtype,
+ )
+ if freeze_text_encoder:
+ del text_encoder_one, text_encoder_two
+ free_memory()
+ elif args.train_text_encoder:
+ del text_encoder_two
+ free_memory()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ transformer = unwrap_model(transformer)
+ transformer = transformer.to(weight_dtype)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
+ if args.train_text_encoder:
+ text_encoder_one = unwrap_model(text_encoder_one)
+ text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+ else:
+ text_encoder_lora_layers = None
+
+ if not pure_textual_inversion:
+ FluxPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ )
+
+ if args.train_text_encoder_ti:
+ embeddings_path = f"{args.output_dir}/{os.path.basename(args.output_dir)}_emb.safetensors"
+ embedding_handler.save_embeddings(embeddings_path)
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = FluxPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ if not pure_textual_inversion:
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {"prompt": args.validation_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ torch_dtype=weight_dtype,
+ is_final_validation=True,
+ )
+
+ save_model_card(
+ model_id if not args.push_to_hub else repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ train_text_encoder_ti=args.train_text_encoder_ti,
+ enable_t5_ti=args.enable_t5_ti,
+ pure_textual_inversion=pure_textual_inversion,
+ token_abstraction_dict=train_dataset.token_abstraction_dict,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ )
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
index 7e1a0298ba1d..024722536d88 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
@@ -71,7 +71,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 5222c8afe6f1..bc06cc9213dc 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -79,7 +79,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
index 0fdca2850784..4ef392baa2b5 100644
--- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py
+++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
@@ -61,7 +61,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index ece2228147e2..011466bc7d58 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -52,7 +52,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/community/README.md b/examples/community/README.md
index 267c8f4bb904..4f16f65df8fa 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -4336,19 +4336,19 @@ The Abstract of the paper:
**64x64**
:-------------------------:
-| |
+| |
- `256×256, nesting_level=1`: 1.776 GiB. With `150` DDIM inference steps:
**64x64** | **256x256**
:-------------------------:|:-------------------------:
-| | |
+| | |
-- `1024×1024, nesting_level=2`: 1.792 GiB. As one can realize the cost of adding another layer is really negligible. With `250` DDIM inference steps:
+- `1024×1024, nesting_level=2`: 1.792 GiB. As one can realize, the cost of adding another layer is really negligible in this context! With `250` DDIM inference steps:
**64x64** | **256x256** | **1024x1024**
:-------------------------:|:-------------------------:|:-------------------------:
-| | | |
+| | | |
```py
from diffusers import DiffusionPipeline
@@ -4362,8 +4362,7 @@ pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-model
prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
-negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy"
-image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images
+image = pipe(prompt, num_inference_steps=50).images
make_image_grid(image, rows=1, cols=len(image))
# pipe.change_nesting_level() # 0, 1, or 2
diff --git a/examples/community/README_community_scripts.md b/examples/community/README_community_scripts.md
index 8432b4e82c9f..2c2f549a2bd5 100644
--- a/examples/community/README_community_scripts.md
+++ b/examples/community/README_community_scripts.md
@@ -8,6 +8,7 @@ If a community script doesn't work as expected, please open an issue and ping th
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
| Using IP-Adapter with negative noise | Using negative noise with IP-adapter to better control the generation (see the [original post](https://github.com/huggingface/diffusers/discussions/7167) on the forum for more details) | [IP-Adapter Negative Noise](#ip-adapter-negative-noise) | | [Álvaro Somoza](https://github.com/asomoza)|
| asymmetric tiling |configure seamless image tiling independently for the X and Y axes | [Asymmetric Tiling](#asymmetric-tiling ) | | [alexisrolland](https://github.com/alexisrolland)|
+| Prompt scheduling callback |Allows changing prompts during a generation | [Prompt Scheduling](#prompt-scheduling ) | | [hlky](https://github.com/hlky)|
## Example usages
@@ -229,4 +230,86 @@ seamless_tiling(pipeline=pipeline, x_axis=False, y_axis=False)
torch.cuda.empty_cache()
image.save('image.png')
-```
\ No newline at end of file
+```
+
+### Prompt Scheduling callback
+
+Prompt scheduling callback allows changing prompts during a generation, like [prompt editing in A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-editing)
+
+```python
+from diffusers import StableDiffusionPipeline
+from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
+from diffusers.configuration_utils import register_to_config
+import torch
+from typing import Any, Dict, Optional
+
+
+pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ variant="fp16",
+ use_safetensors=True,
+).to("cuda")
+pipeline.safety_checker = None
+pipeline.requires_safety_checker = False
+
+
+class SDPromptScheduleCallback(PipelineCallback):
+ @register_to_config
+ def __init__(
+ self,
+ prompt: str,
+ negative_prompt: Optional[str] = None,
+ num_images_per_prompt: int = 1,
+ cutoff_step_ratio=1.0,
+ cutoff_step_index=None,
+ ):
+ super().__init__(
+ cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
+ )
+
+ tensor_inputs = ["prompt_embeds"]
+
+ def callback_fn(
+ self, pipeline, step_index, timestep, callback_kwargs
+ ) -> Dict[str, Any]:
+ cutoff_step_ratio = self.config.cutoff_step_ratio
+ cutoff_step_index = self.config.cutoff_step_index
+
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+ cutoff_step = (
+ cutoff_step_index
+ if cutoff_step_index is not None
+ else int(pipeline.num_timesteps * cutoff_step_ratio)
+ )
+
+ if step_index == cutoff_step:
+ prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
+ prompt=self.config.prompt,
+ negative_prompt=self.config.negative_prompt,
+ device=pipeline._execution_device,
+ num_images_per_prompt=self.config.num_images_per_prompt,
+ do_classifier_free_guidance=pipeline.do_classifier_free_guidance,
+ )
+ if pipeline.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+ callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+ return callback_kwargs
+
+callback = MultiPipelineCallbacks(
+ [
+ SDPromptScheduleCallback(
+ prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
+ negative_prompt="Deformed, ugly, bad anatomy",
+ cutoff_step_ratio=0.25,
+ )
+ ]
+)
+
+image = pipeline(
+ prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
+ negative_prompt="Deformed, ugly, bad anatomy",
+ callback_on_step_end=callback,
+ callback_on_step_end_tensor_inputs=["prompt_embeds"],
+).images[0]
+```
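+
+At the configured cutoff step the callback re-encodes the scheduled prompt and writes the new `prompt_embeds` into `callback_kwargs`, so the remaining denoising steps follow the new prompt while the earlier steps keep the original one.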
diff --git a/examples/community/marigold_depth_estimation.py b/examples/community/marigold_depth_estimation.py
index 92f01d046ef9..a8f406309a52 100644
--- a/examples/community/marigold_depth_estimation.py
+++ b/examples/community/marigold_depth_estimation.py
@@ -43,7 +43,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
class MarigoldDepthOutput(BaseOutput):
diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py
index 7ef1438f7204..7ac0ab542910 100644
--- a/examples/community/matryoshka.py
+++ b/examples/community/matryoshka.py
@@ -107,15 +107,16 @@
>>> # nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64
>>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
- >>> custom_pipeline="matryoshka").to("cuda")
+ ... nesting_level=0,
+ ... trust_remote_code=False, # One needs to give permission for this code to run
+ ... ).to("cuda")
>>> prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
>>> prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
- >>> negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy"
- >>> image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images
+ >>> image = pipe(prompt, num_inference_steps=50).images
>>> make_image_grid(image, rows=1, cols=len(image))
- >>> pipe.change_nesting_level() # 0, 1, or 2
+ >>> # pipe.change_nesting_level() # 0, 1, or 2
>>> # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
```
"""
@@ -420,6 +421,7 @@ def __init__(
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
self.scales = None
+ self.schedule_shifted_power = 1.0
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
"""
@@ -532,6 +534,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
def get_schedule_shifted(self, alpha_prod, scale_factor=None):
if (scale_factor is not None) and (scale_factor > 1): # rescale noise schedule
+ scale_factor = scale_factor**self.schedule_shifted_power
snr = alpha_prod / (1 - alpha_prod)
scaled_snr = snr / scale_factor
alpha_prod = 1 / (1 + 1 / scaled_snr)
@@ -639,17 +642,14 @@ def step(
# 4. Clip or threshold "predicted x_0"
if self.config.thresholding:
if len(model_output) > 1:
- pred_original_sample = [
- self._threshold_sample(p_o_s * scale) / scale
- for p_o_s, scale in zip(pred_original_sample, self.scales)
- ]
+ pred_original_sample = [self._threshold_sample(p_o_s) for p_o_s in pred_original_sample]
else:
pred_original_sample = self._threshold_sample(pred_original_sample)
elif self.config.clip_sample:
if len(model_output) > 1:
pred_original_sample = [
- (p_o_s * scale).clamp(-self.config.clip_sample_range, self.config.clip_sample_range) / scale
- for p_o_s, scale in zip(pred_original_sample, self.scales)
+ p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)
+ for p_o_s in pred_original_sample
]
else:
pred_original_sample = pred_original_sample.clamp(
@@ -3816,6 +3816,8 @@ def __init__(
if hasattr(unet, "nest_ratio"):
scheduler.scales = unet.nest_ratio + [1]
+ if nesting_level == 2:
+ scheduler.schedule_shifted_power = 2.0
self.register_modules(
text_encoder=text_encoder,
@@ -3842,12 +3844,14 @@ def change_nesting_level(self, nesting_level: int):
).to(self.device)
self.config.nesting_level = 1
self.scheduler.scales = self.unet.nest_ratio + [1]
+ self.scheduler.schedule_shifted_power = 1.0
elif nesting_level == 2:
self.unet = NestedUNet2DConditionModel.from_pretrained(
"tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_2"
).to(self.device)
self.config.nesting_level = 2
self.scheduler.scales = self.unet.nest_ratio + [1]
+ self.scheduler.schedule_shifted_power = 2.0
else:
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
@@ -4627,8 +4631,8 @@ def __call__(
image = latents
if self.scheduler.scales is not None:
- for i, (img, scale) in enumerate(zip(image, self.scheduler.scales)):
- image[i] = self.image_processor.postprocess(img * scale, output_type=output_type)[0]
+ for i, img in enumerate(image):
+ image[i] = self.image_processor.postprocess(img, output_type=output_type)[0]
else:
image = self.image_processor.postprocess(image, output_type=output_type)
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
index 611026675daf..0750df79eb0d 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -73,7 +73,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
index 8090926974c4..493742691286 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
@@ -66,7 +66,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
index fa7e7f1febee..824f148c58fd 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -79,7 +79,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
index 12d7db09a361..a334c27e7d86 100644
--- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -72,7 +72,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
index cc5e6812127e..6e5e85172f14 100644
--- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -78,7 +78,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/controlnet/README_sd3.md b/examples/controlnet/README_sd3.md
index 1788e07a21d6..7a7b4841125f 100644
--- a/examples/controlnet/README_sd3.md
+++ b/examples/controlnet/README_sd3.md
@@ -104,7 +104,7 @@ from diffusers.utils import load_image
import torch
base_model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
-controlnet_path = "sd3-controlnet-out/checkpoint-6500/controlnet"
+controlnet_path = "DavyMorgan/sd3-controlnet-out"
controlnet = SD3ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py
index 2e902db7ffc7..a2aa266cdfbc 100644
--- a/examples/controlnet/train_controlnet.py
+++ b/examples/controlnet/train_controlnet.py
@@ -60,7 +60,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
@@ -1048,7 +1048,9 @@ def load_model_hook(models, input_dir):
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
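+            # run add_noise in fp32 and cast the result back to weight_dtype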
+ noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to(
+ dtype=weight_dtype
+ )
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index 1aa9e881fca5..44c286cd2a40 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -60,7 +60,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py
index 5969218f3c3e..ca822b16eae2 100644
--- a/examples/controlnet/train_controlnet_flux.py
+++ b/examples/controlnet/train_controlnet_flux.py
@@ -65,7 +65,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py
index 4fae8a072c6f..2bb68220e268 100644
--- a/examples/controlnet/train_controlnet_sd3.py
+++ b/examples/controlnet/train_controlnet_sd3.py
@@ -50,7 +50,7 @@
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
@@ -59,22 +59,11 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.30.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
-def image_grid(imgs, rows, cols):
- assert len(imgs) == rows * cols
-
- w, h = imgs[0].size
- grid = Image.new("RGB", size=(cols * w, rows * h))
-
- for i, img in enumerate(imgs):
- grid.paste(img, box=(i % cols * w, i // cols * h))
- return grid
-
-
def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
logger.info("Running validation... ")
@@ -224,7 +213,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
validation_image.save(os.path.join(repo_folder, "image_control.png"))
img_str += f"prompt: {validation_prompt}\n"
images = [validation_image] + images
- image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
+ make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
img_str += f"![images_{i})](./images_{i}.png)\n"
model_description = f"""
diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py
index 877ca6135849..c034c027cbcd 100644
--- a/examples/controlnet/train_controlnet_sdxl.py
+++ b/examples/controlnet/train_controlnet_sdxl.py
@@ -61,7 +61,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -1210,7 +1210,9 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
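+            # run add_noise in fp32 and cast the result back to weight_dtype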
+ noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to(
+ dtype=weight_dtype
+ )
# ControlNet conditioning.
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index e498ca98b1c7..151817247350 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -63,7 +63,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 5099107118e4..4b614807cfc4 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -63,7 +63,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py
index 29fd5e78535d..3023b28aca7f 100644
--- a/examples/dreambooth/train_dreambooth_flax.py
+++ b/examples/dreambooth/train_dreambooth_flax.py
@@ -35,7 +35,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index 8e0f4e09a461..db4788281cf2 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -64,7 +64,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 5d7d697bb21d..bf778693a88d 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -70,7 +70,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 11cba745cc4a..b09e5b38b2b1 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -72,7 +72,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 8d0b6853eeec..4b39dcfe41b0 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -72,7 +72,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
@@ -86,6 +86,15 @@ def save_model_card(
validation_prompt=None,
repo_folder=None,
):
+ if "large" in base_model:
+ model_variant = "SD3.5-Large"
+ license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
+ variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"]
+ else:
+ model_variant = "SD3"
+ license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
+ variant_tags = ["sd3", "sd3-diffusers"]
+
widget_dict = []
if images is not None:
for i, image in enumerate(images):
@@ -95,7 +104,7 @@ def save_model_card(
)
model_description = f"""
-# SD3 DreamBooth LoRA - {repo_id}
+# {model_variant} DreamBooth LoRA - {repo_id}
@@ -120,7 +129,7 @@ def save_model_card(
```py
from diffusers import AutoPipelineForText2Image
import torch
-pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16).to('cuda')
+pipeline = AutoPipelineForText2Image.from_pretrained({base_model}, torch_dtype=torch.float16).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
```
@@ -135,7 +144,7 @@ def save_model_card(
## License
-Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
+Please adhere to the licensing terms as described [here]({license_url}).
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
@@ -151,11 +160,11 @@ def save_model_card(
"diffusers-training",
"diffusers",
"lora",
- "sd3",
- "sd3-diffusers",
"template:sd-lora",
]
+ tags += variant_tags
+
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 016464165c44..bf8c8f7d0578 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -78,7 +78,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py
index 455ba5a9293d..5d10345304ab 100644
--- a/examples/dreambooth/train_dreambooth_sd3.py
+++ b/examples/dreambooth/train_dreambooth_sd3.py
@@ -63,7 +63,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
@@ -77,6 +77,15 @@ def save_model_card(
validation_prompt=None,
repo_folder=None,
):
+ if "large" in base_model:
+ model_variant = "SD3.5-Large"
+ license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
+ variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"]
+ else:
+ model_variant = "SD3"
+ license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
+ variant_tags = ["sd3", "sd3-diffusers"]
+
widget_dict = []
if images is not None:
for i, image in enumerate(images):
@@ -86,7 +95,7 @@ def save_model_card(
)
model_description = f"""
-# SD3 DreamBooth - {repo_id}
+# {model_variant} DreamBooth - {repo_id}
@@ -113,7 +122,7 @@ def save_model_card(
## License
-Please adhere to the licensing terms as described `[here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)`.
+Please adhere to the licensing terms as described `[here]({license_url})`.
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
@@ -128,10 +137,9 @@ def save_model_card(
"text-to-image",
"diffusers-training",
"diffusers",
- "sd3",
- "sd3-diffusers",
"template:sd-lora",
]
+ tags += variant_tags
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py
index 3cb0c6702599..125368841fa8 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py
@@ -57,7 +57,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index c88be6d16d88..4cb9f0e1c544 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -60,7 +60,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
index 9caa3694d636..40016f797341 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
@@ -52,7 +52,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
index 23f5d342b396..3ec622c09239 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
@@ -46,7 +46,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
index 6ed3377db131..fbd843bc3307 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
@@ -46,7 +46,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
index 448429444448..c264a4ce8c7c 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
@@ -51,7 +51,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
index 02a064fa81ed..e694d709360c 100644
--- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py
+++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
@@ -60,7 +60,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 684bf352a6c1..6857df61d0c2 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -57,7 +57,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py
index 4a80067d693d..712bc34429a0 100644
--- a/examples/text_to_image/train_text_to_image_flax.py
+++ b/examples/text_to_image/train_text_to_image_flax.py
@@ -49,7 +49,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 379519b4c812..5f432fcc7adf 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -56,7 +56,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index fe098c8638d5..9a4fa23fada3 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -68,7 +68,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index bcf0fa9eb0ac..b34feb6f715c 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -55,7 +55,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 6b710531836b..43e8bf4e9072 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -81,7 +81,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index ee7b1580d145..fff633e75684 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -56,7 +56,7 @@
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py
index a4629f0f43d6..3a9da9fb11df 100644
--- a/examples/textual_inversion/textual_inversion_sdxl.py
+++ b/examples/textual_inversion/textual_inversion_sdxl.py
@@ -76,7 +76,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 1f5e1de240cb..a80e4c55190d 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -29,7 +29,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py
index d16dce921896..b56e39847983 100644
--- a/examples/vqgan/train_vqgan.py
+++ b/examples/vqgan/train_vqgan.py
@@ -50,7 +50,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
index b4c9a44bb5b2..d57d910599ee 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
@@ -50,7 +50,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
index eba8de69203a..2d9df8387333 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
+++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
@@ -51,7 +51,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.31.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/scripts/convert_sd3_to_diffusers.py b/scripts/convert_sd3_to_diffusers.py
index 4f32745dae75..1f9c434b39d0 100644
--- a/scripts/convert_sd3_to_diffusers.py
+++ b/scripts/convert_sd3_to_diffusers.py
@@ -16,10 +16,9 @@
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", type=str)
parser.add_argument("--output_path", type=str)
-parser.add_argument("--dtype", type=str, default="fp16")
+parser.add_argument("--dtype", type=str)
args = parser.parse_args()
-dtype = torch.float16 if args.dtype == "fp16" else torch.float32
def load_original_checkpoint(ckpt_path):
@@ -40,7 +39,9 @@ def swap_scale_shift(weight, dim):
return new_weight
-def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim):
+def convert_sd3_transformer_checkpoint_to_diffusers(
+ original_state_dict, num_layers, caption_projection_dim, dual_attention_layers, has_qk_norm
+):
converted_state_dict = {}
# Positional and patch embeddings.
@@ -110,6 +111,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+ # qk norm
+ if has_qk_norm:
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = original_state_dict.pop(
+ f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = original_state_dict.pop(
+ f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = original_state_dict.pop(
+ f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = original_state_dict.pop(
+ f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+ )
+
# output projections.
converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -125,6 +141,39 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay
f"joint_blocks.{i}.context_block.attn.proj.bias"
)
+ # attn2
+ if i in dual_attention_layers:
+ # Q, K, V
+ sample_q2, sample_k2, sample_v2 = torch.chunk(
+ original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+ )
+ sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+ original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+ # qk norm
+ if has_qk_norm:
+ converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = original_state_dict.pop(
+ f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = original_state_dict.pop(
+ f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+ )
+
+ # output projections.
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = original_state_dict.pop(
+ f"joint_blocks.{i}.x_block.attn2.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = original_state_dict.pop(
+ f"joint_blocks.{i}.x_block.attn2.proj.bias"
+ )
+
# norms.
converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop(
f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
@@ -195,25 +244,79 @@ def is_vae_in_checkpoint(original_state_dict):
)
+def get_attn2_layers(state_dict):
+ attn2_layers = []
+ for key in state_dict.keys():
+ if "attn2." in key:
+ # Extract the layer number from the key
+ layer_num = int(key.split(".")[1])
+ attn2_layers.append(layer_num)
+ return tuple(sorted(set(attn2_layers)))
+
+
+def get_pos_embed_max_size(state_dict):
+ num_patches = state_dict["pos_embed"].shape[1]
+ pos_embed_max_size = int(num_patches**0.5)
+ return pos_embed_max_size
+
+
+def get_caption_projection_dim(state_dict):
+ caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+ return caption_projection_dim
+
+
def main(args):
original_ckpt = load_original_checkpoint(args.checkpoint_path)
+ original_dtype = next(iter(original_ckpt.values())).dtype
+
+ # Initialize dtype with a default value
+ dtype = None
+
+ if args.dtype is None:
+ dtype = original_dtype
+ elif args.dtype == "fp16":
+ dtype = torch.float16
+ elif args.dtype == "bf16":
+ dtype = torch.bfloat16
+ elif args.dtype == "fp32":
+ dtype = torch.float32
+ else:
+ raise ValueError(f"Unsupported dtype: {args.dtype}")
+
+ if dtype != original_dtype:
+ print(
+ f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution."
+ )
+
num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1 # noqa: C401
- caption_projection_dim = 1536
+
+ caption_projection_dim = get_caption_projection_dim(original_ckpt)
+
+ # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+ attn2_layers = get_attn2_layers(original_ckpt)
+
+ # sd3.5 uses qk norm ("rms_norm")
+ has_qk_norm = any("ln_q" in key for key in original_ckpt.keys())
+
+ # sd3.5 2b uses pos_embed_max_size=384; sd3.0 and sd3.5 8b use 192
+ pos_embed_max_size = get_pos_embed_max_size(original_ckpt)
converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
- original_ckpt, num_layers, caption_projection_dim
+ original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm
)
with CTX():
transformer = SD3Transformer2DModel(
- sample_size=64,
+ sample_size=128,
patch_size=2,
in_channels=16,
joint_attention_dim=4096,
num_layers=num_layers,
caption_projection_dim=caption_projection_dim,
- num_attention_heads=24,
- pos_embed_max_size=192,
+ num_attention_heads=num_layers,
+ pos_embed_max_size=pos_embed_max_size,
+ qk_norm="rms_norm" if has_qk_norm else None,
+ dual_attention_layers=attn2_layers,
)
if is_accelerate_available():
load_model_dict_into_meta(transformer, converted_transformer_state_dict)
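
For reference, here is a minimal sketch of the dtype handling the hunks above introduce: when --dtype is omitted, the converted weights keep the checkpoint's original dtype instead of silently defaulting to fp16. The helper name resolve_dtype is illustrative and not part of the script.

import torch

def resolve_dtype(requested, original_dtype):
    # `requested` mirrors the --dtype CLI flag; None means "keep the checkpoint dtype".
    mapping = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    if requested is None:
        return original_dtype
    if requested not in mapping:
        raise ValueError(f"Unsupported dtype: {requested}")
    return mapping[requested]

assert resolve_dtype(None, torch.bfloat16) is torch.bfloat16
assert resolve_dtype("fp16", torch.bfloat16) is torch.float16
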
diff --git a/setup.py b/setup.py
index 89e18cab629b..d82ecad86771 100644
--- a/setup.py
+++ b/setup.py
@@ -130,7 +130,7 @@
"regex!=2019.12.17",
"requests",
"tensorboard",
- "torch>=1.4",
+ "torch>=1.4,<2.5.0",
"torchvision",
"transformers>=4.41.2",
"urllib3<=2.0.0",
@@ -254,7 +254,7 @@ def run(self):
setup(
name="diffusers",
- version="0.31.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ version="0.32.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 019d744730ab..789458a26299 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.31.0.dev0"
+__version__ = "0.32.0.dev0"
from typing import TYPE_CHECKING
@@ -31,6 +31,7 @@
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
+ "quantizers.quantization_config": ["BitsAndBytesConfig"],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
@@ -124,7 +125,6 @@
"VQModel",
]
)
-
_import_structure["optimization"] = [
"get_constant_schedule",
"get_constant_schedule_with_warmup",
@@ -156,6 +156,7 @@
"StableDiffusionMixin",
]
)
+ _import_structure["quantizers"] = ["DiffusersQuantizer"]
_import_structure["schedulers"].extend(
[
"AmusedScheduler",
@@ -256,6 +257,7 @@
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"CLIPImageProjection",
+ "CogVideoXFunControlPipeline",
"CogVideoXImageToVideoPipeline",
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
@@ -537,6 +539,7 @@
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
+ from .quantizers.quantization_config import BitsAndBytesConfig
try:
if not is_onnx_available():
@@ -631,6 +634,7 @@
ScoreSdeVePipeline,
StableDiffusionMixin,
)
+ from .quantizers import DiffusersQuantizer
from .schedulers import (
AmusedScheduler,
CMStochasticIterativeScheduler,
@@ -711,6 +715,7 @@
AudioLDMPipeline,
AuraFlowPipeline,
CLIPImageProjection,
+ CogVideoXFunControlPipeline,
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
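
A hedged sketch of what the newly exported BitsAndBytesConfig enables, together with the quantization_config argument added to ModelMixin.from_pretrained further down in this patch. The repo id and subfolder are illustrative only; a CUDA device and an installed bitsandbytes are assumed.

import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",  # illustrative repo id
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
print(transformer.hf_quantizer)  # a DiffusersQuantizer instance is attached after loading
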
diff --git a/src/diffusers/callbacks.py b/src/diffusers/callbacks.py
index 38542407e31f..4b8b15368c47 100644
--- a/src/diffusers/callbacks.py
+++ b/src/diffusers/callbacks.py
@@ -97,13 +97,17 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
class SDXLCFGCutoffCallback(PipelineCallback):
"""
- Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
- `cutoff_step_index`), this callback will disable the CFG.
+ Callback function for the base Stable Diffusion XL Pipelines. After a certain number of steps (set by
+ `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
"""
- tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
+ tensor_inputs = [
+ "prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ ]
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
@@ -129,6 +133,55 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
callback_kwargs[self.tensor_inputs[2]] = add_time_ids
+
+ return callback_kwargs
+
+
+class SDXLControlnetCFGCutoffCallback(PipelineCallback):
+ """
+ Callback function for the Controlnet Stable Diffusion XL Pipelines. After a certain number of steps (set by
+ `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
+
+ Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
+ """
+
+ tensor_inputs = [
+ "prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ "image",
+ ]
+
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+ cutoff_step_ratio = self.config.cutoff_step_ratio
+ cutoff_step_index = self.config.cutoff_step_index
+
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+ cutoff_step = (
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+ )
+
+ if step_index == cutoff_step:
+ prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
+ prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
+
+ add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
+ add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens
+
+ add_time_ids = callback_kwargs[self.tensor_inputs[2]]
+ add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector
+
+ # For Controlnet
+ image = callback_kwargs[self.tensor_inputs[3]]
+ image = image[-1:]
+
+ pipeline._guidance_scale = 0.0
+
+ callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+ callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
+ callback_kwargs[self.tensor_inputs[2]] = add_time_ids
+ callback_kwargs[self.tensor_inputs[3]] = image
+
return callback_kwargs
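
A hedged usage sketch of the new SDXLControlnetCFGCutoffCallback. The checkpoint ids and the blank control image are placeholders; a GPU and downloaded weights are assumed.

import torch
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.callbacks import SDXLControlnetCFGCutoffCallback

callback = SDXLControlnetCFGCutoffCallback(cutoff_step_ratio=0.4)  # disable CFG after 40% of the steps

controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

control_image = Image.new("RGB", (1024, 1024))  # placeholder; a real canny map would go here
image = pipe(
    prompt="a photo of a cat",
    image=control_image,
    callback_on_step_end=callback,  # the pipeline picks up callback.tensor_inputs automatically
).images[0]
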
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 3dccd785cae4..11d45dc64d97 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -510,6 +510,9 @@ def extract_init_dict(cls, config_dict, **kwargs):
# remove private attributes
config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
+ # remove quantization_config
+ config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}
+
# 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
init_dict = {}
for key in expected_keys:
@@ -586,10 +589,19 @@ def to_json_saveable(value):
value = value.as_posix()
return value
+ if "quantization_config" in config_dict:
+ config_dict["quantization_config"] = (
+ config_dict.quantization_config.to_dict()
+ if not isinstance(config_dict.quantization_config, dict)
+ else config_dict.quantization_config
+ )
+
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
# Don't save "_ignore_files" or "_use_default_values"
config_dict.pop("_ignore_files", None)
config_dict.pop("_use_default_values", None)
+ # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
+ _ = config_dict.pop("_pre_quantization_dtype", None)
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
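
A minimal stand-in illustrating the serialization rule added above: quantization configs are turned into plain dicts before JSON dumping and the non-serializable `_pre_quantization_dtype` entry is dropped. The dict below is illustrative, not a real model config.

import json

config_dict = {
    "sample_size": 128,
    "quantization_config": {"quant_method": "bitsandbytes", "load_in_4bit": True},
    "_pre_quantization_dtype": "torch.bfloat16",  # stand-in; in practice this is a torch.dtype object
}
config_dict.pop("_pre_quantization_dtype", None)
print(json.dumps(config_dict, indent=2, sort_keys=True))
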
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 9e7bf242eca7..0e421b71e48d 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -38,7 +38,7 @@
"regex": "regex!=2019.12.17",
"requests": "requests",
"tensorboard": "tensorboard",
- "torch": "torch>=1.4",
+ "torch": "torch>=1.4,<2.5.0",
"torchvision": "torchvision",
"transformers": "transformers>=4.41.2",
"urllib3": "urllib3<=2.0.0",
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index 236fbd0c2295..d1bad8b5a7cd 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -75,6 +75,7 @@
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
"stable_cascade_stage_c": "clip_txt_mapper.weight",
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
+ "sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
@@ -113,6 +114,9 @@
"sd3": {
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
},
+ "sd35_large": {
+ "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large",
+ },
"animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
"animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
"animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
@@ -504,9 +508,12 @@ def infer_diffusers_model_type(checkpoint):
):
model_type = "stable_cascade_stage_b"
- elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
+ elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
model_type = "sd3"
+ elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
+ model_type = "sd35_large"
+
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
model_type = "animatediff_scribble"
@@ -1670,6 +1677,22 @@ def swap_scale_shift(weight, dim):
return new_weight
+def get_attn2_layers(state_dict):
+ attn2_layers = []
+ for key in state_dict.keys():
+ if "attn2." in key:
+ # Extract the layer number from the key
+ layer_num = int(key.split(".")[1])
+ attn2_layers.append(layer_num)
+
+ return tuple(sorted(set(attn2_layers)))
+
+
+def get_caption_projection_dim(state_dict):
+ caption_projection_dim = state_dict["context_embedder.weight"].shape[0]
+ return caption_projection_dim
+
+
def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
@@ -1678,7 +1701,10 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1 # noqa: C401
- caption_projection_dim = 1536
+ dual_attention_layers = get_attn2_layers(checkpoint)
+
+ caption_projection_dim = get_caption_projection_dim(checkpoint)
+ has_qk_norm = any("ln_q" in key for key in checkpoint.keys())
# Positional and patch embeddings.
converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
@@ -1735,6 +1761,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+ # qk norm
+ if has_qk_norm:
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn.ln_q.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn.ln_k.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.attn.ln_q.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.context_block.attn.ln_k.weight"
+ )
+
# output projections.
converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.attn.proj.weight"
@@ -1750,6 +1791,38 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
f"joint_blocks.{i}.context_block.attn.proj.bias"
)
+ if i in dual_attention_layers:
+ # Q, K, V
+ sample_q2, sample_k2, sample_v2 = torch.chunk(
+ checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
+ )
+ sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk(
+ checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2])
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias])
+
+ # qk norm
+ if has_qk_norm:
+ converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn2.ln_q.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn2.ln_k.weight"
+ )
+
+ # output projections.
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn2.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
+ f"joint_blocks.{i}.x_block.attn2.proj.bias"
+ )
+
# norms.
converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
diff --git a/src/diffusers/loaders/textual_inversion.py b/src/diffusers/loaders/textual_inversion.py
index 574b89233cc1..30098c955d6b 100644
--- a/src/diffusers/loaders/textual_inversion.py
+++ b/src/diffusers/loaders/textual_inversion.py
@@ -561,6 +561,8 @@ def unload_textual_inversion(
tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
key_id += 1
tokenizer._update_trie()
+ # set correct total vocab size after removing tokens
+ tokenizer._update_total_vocab_size()
# Delete from text encoder
text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
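
A hedged sketch of the behavior the tokenizer fix above targets: after unloading a textual inversion, len(tokenizer) should return to its original value, which previously stayed stale because the cached total vocab size was not refreshed. Checkpoint and concept ids are illustrative.

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")  # illustrative
original_vocab_size = len(pipe.tokenizer)

pipe.load_textual_inversion("sd-concepts-library/cat-toy")  # illustrative concept
pipe.unload_textual_inversion()

assert len(pipe.tokenizer) == original_vocab_size
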
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 84db0d061768..02ed1f965abf 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -22,7 +22,7 @@
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
from .attention_processor import Attention, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
+from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
logger = logging.get_logger(__name__)
@@ -100,13 +100,25 @@ class JointTransformerBlock(nn.Module):
processing of `context` conditions.
"""
- def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ context_pre_only: bool = False,
+ qk_norm: Optional[str] = None,
+ use_dual_attention: bool = False,
+ ):
super().__init__()
+ self.use_dual_attention = use_dual_attention
self.context_pre_only = context_pre_only
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
- self.norm1 = AdaLayerNormZero(dim)
+ if use_dual_attention:
+ self.norm1 = SD35AdaLayerNormZeroX(dim)
+ else:
+ self.norm1 = AdaLayerNormZero(dim)
if context_norm_type == "ada_norm_continous":
self.norm1_context = AdaLayerNormContinuous(
@@ -118,12 +130,14 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_onl
raise ValueError(
f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
)
+
if hasattr(F, "scaled_dot_product_attention"):
processor = JointAttnProcessor2_0()
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)
+
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
@@ -134,8 +148,25 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_onl
context_pre_only=context_pre_only,
bias=True,
processor=processor,
+ qk_norm=qk_norm,
+ eps=1e-6,
)
+ if use_dual_attention:
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ processor=processor,
+ qk_norm=qk_norm,
+ eps=1e-6,
+ )
+ else:
+ self.attn2 = None
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
@@ -159,7 +190,12 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
):
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+ if self.use_dual_attention:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
+ hidden_states, emb=temb
+ )
+ else:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
if self.context_pre_only:
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
@@ -177,6 +213,11 @@ def forward(
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
+ if self.use_dual_attention:
+ attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+ attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
+ hidden_states = hidden_states + attn_output2
+
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self._chunk_size is not None:
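
A hedged smoke test of the extended block signature: a dual-attention block with RMS qk-norm run on random tensors. The dimensions are small and illustrative rather than SD3.5's real configuration; PyTorch 2.x is assumed for scaled_dot_product_attention.

import torch
from diffusers.models.attention import JointTransformerBlock

block = JointTransformerBlock(
    dim=64,
    num_attention_heads=4,
    attention_head_dim=16,
    context_pre_only=False,
    qk_norm="rms_norm",
    use_dual_attention=True,
)
hidden_states = torch.randn(1, 16, 64)         # image tokens
encoder_hidden_states = torch.randn(1, 8, 64)  # text tokens
temb = torch.randn(1, 64)                      # timestep embedding

encoder_hidden_states, hidden_states = block(hidden_states, encoder_hidden_states, temb)
print(hidden_states.shape, encoder_hidden_states.shape)
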
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index d333590982e3..e735c4ee7d17 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -193,7 +193,7 @@ def __init__(
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)
else:
- raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")
if cross_attention_norm is None:
self.norm_cross = None
@@ -250,6 +250,10 @@ def __init__(
elif qk_norm == "rms_norm":
self.norm_added_q = RMSNorm(dim_head, eps=eps)
self.norm_added_k = RMSNorm(dim_head, eps=eps)
+ else:
+ raise ValueError(
+ f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
+ )
else:
self.norm_added_q = None
self.norm_added_k = None
@@ -1050,61 +1054,72 @@ def __call__(
) -> torch.FloatTensor:
residual = hidden_states
- input_ndim = hidden_states.ndim
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
- context_input_ndim = encoder_hidden_states.ndim
- if context_input_ndim == 4:
- batch_size, channel, height, width = encoder_hidden_states.shape
- encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size = encoder_hidden_states.shape[0]
+ batch_size = hidden_states.shape[0]
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- # attention
- query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
- key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
- value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
-
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # `context` projections.
+ if encoder_hidden_states is not None:
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+ batch_size, -1, attn.heads, head_dim
+ ).transpose(1, 2)
+
+ if attn.norm_added_q is not None:
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+ if attn.norm_added_k is not None:
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
- # Split the attention outputs.
- hidden_states, encoder_hidden_states = (
- hidden_states[:, : residual.shape[1]],
- hidden_states[:, residual.shape[1] :],
- )
+ if encoder_hidden_states is not None:
+ # Split the attention outputs.
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : residual.shape[1]],
+ hidden_states[:, residual.shape[1] :],
+ )
+ if not attn.context_pre_only:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
- if not attn.context_pre_only:
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
- if context_input_ndim == 4:
- encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- return hidden_states, encoder_hidden_states
+ if encoder_hidden_states is not None:
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
class PAGJointAttnProcessor2_0:
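
A hedged sketch of the processor's new single-stream path: when no encoder hidden states are passed (as in the dual-attention `attn2` blocks above), only `hidden_states` is returned instead of a tuple. Sizes are illustrative.

import torch
from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0

attn = Attention(
    query_dim=64,
    heads=4,
    dim_head=16,
    out_dim=64,
    bias=True,
    qk_norm="rms_norm",
    eps=1e-6,
    processor=JointAttnProcessor2_0(),
)
out = attn(hidden_states=torch.randn(1, 16, 64))  # no encoder_hidden_states -> single tensor
print(out.shape)  # torch.Size([1, 16, 64])
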
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
index 7834206ddb4a..68b49d72acc5 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -1182,7 +1182,8 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
frame_batch_size = self.num_sample_frames_batch_size
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
- num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+ # As the extra single frame is handled inside the loop, it is not required to round up here.
+ num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
enc = []
@@ -1330,7 +1331,8 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
row = []
for j in range(0, width, overlap_width):
# Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
- num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+ # As the extra single frame is handled inside the loop, it is not required to round up here.
+ num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
time = []
@@ -1409,7 +1411,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
- num_batches = num_frames // frame_batch_size
+ num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
time = []
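
A tiny arithmetic illustration of the batching fix above: clamping with max(..., 1) keeps short clips (fewer frames than one frame batch) from producing zero batches and skipping the loop entirely. The frame counts are illustrative.

frame_batch_size = 8
for num_frames in (1, 5, 9, 17, 49):
    old = num_frames // frame_batch_size if num_frames > 1 else 1
    new = max(num_frames // frame_batch_size, 1)
    print(f"num_frames={num_frames:>2}  old={old}  new={new}")
# num_frames=5 used to yield 0 batches; it now yields 1.
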
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index c9eb664443b5..932a94571107 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -25,6 +25,7 @@
import torch
from huggingface_hub.utils import EntryNotFoundError
+from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_FILE_EXTENSION,
@@ -54,11 +55,36 @@
# Adapted from `transformers` (see modeling_utils.py)
-def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
+def _determine_device_map(
+ model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None
+):
if isinstance(device_map, str):
+ special_dtypes = {}
+ if hf_quantizer is not None:
+ special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype))
+ special_dtypes.update(
+ {
+ name: torch.float32
+ for name, _ in model.named_parameters()
+ if any(m in name for m in keep_in_fp32_modules)
+ }
+ )
+
+ target_dtype = torch_dtype
+ if hf_quantizer is not None:
+ target_dtype = hf_quantizer.adjust_target_dtype(target_dtype)
+
no_split_modules = model._get_no_split_modules(device_map)
device_map_kwargs = {"no_split_module_classes": no_split_modules}
+ if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
+ device_map_kwargs["special_dtypes"] = special_dtypes
+ elif len(special_dtypes) > 0:
+ logger.warning(
+ "This model has some weights that should be kept in higher precision, you need to upgrade "
+ "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
+ )
+
if device_map != "sequential":
max_memory = get_balanced_memory(
model,
@@ -70,8 +96,14 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_
else:
max_memory = get_max_memory(max_memory)
+ if hf_quantizer is not None:
+ max_memory = hf_quantizer.adjust_max_memory(max_memory)
+
device_map_kwargs["max_memory"] = max_memory
- device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
+ device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
+
+ if hf_quantizer is not None:
+ hf_quantizer.validate_environment(device_map=device_map)
return device_map
@@ -100,6 +132,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
+ # TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change
+ # when refactoring the _merge_sharded_checkpoints() method later.
+ if isinstance(checkpoint_file, dict):
+ return checkpoint_file
try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
@@ -137,29 +173,67 @@ def load_model_dict_into_meta(
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
model_name_or_path: Optional[str] = None,
+ hf_quantizer=None,
+ keep_in_fp32_modules=None,
) -> List[str]:
- device = device or torch.device("cpu")
+ if hf_quantizer is None:
+ device = device or torch.device("cpu")
dtype = dtype or torch.float32
+ is_quantized = hf_quantizer is not None
+ is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
-
- unexpected_keys = []
empty_state_dict = model.state_dict()
+ unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
+
for param_name, param in state_dict.items():
if param_name not in empty_state_dict:
- unexpected_keys.append(param_name)
continue
+ set_module_kwargs = {}
+ # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
+ # in int/uint/bool and not cast them.
+ # TODO: revisit cases when param.dtype == torch.float8_e4m3fn
+ if torch.is_floating_point(param):
+ if (
+ keep_in_fp32_modules is not None
+ and any(
+ module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
+ )
+ and dtype == torch.float16
+ ):
+ param = param.to(torch.float32)
+ if accepts_dtype:
+ set_module_kwargs["dtype"] = torch.float32
+ else:
+ param = param.to(dtype)
+ if accepts_dtype:
+ set_module_kwargs["dtype"] = dtype
+
+ # bnb params are flattened.
if empty_state_dict[param_name].shape != param.shape:
- model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
- raise ValueError(
- f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
- )
+ if (
+ is_quant_method_bnb
+ and hf_quantizer.pre_quantized
+ and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+ ):
+ hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
+ elif not is_quant_method_bnb:
+ model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+ raise ValueError(
+ f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+ )
- if accepts_dtype:
- set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+ if is_quantized and (
+ hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
+ ):
+ hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
else:
- set_module_tensor_to_device(model, param_name, device, value=param)
+ if accepts_dtype:
+ set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
+ else:
+ set_module_tensor_to_device(model, param_name, device, value=param)
+
return unexpected_keys
@@ -231,6 +305,35 @@ def _fetch_index_file(
return index_file
+# Adapted from
+# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
+def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
+ weight_map = sharded_metadata.get("weight_map", None)
+ if weight_map is None:
+ raise KeyError("'weight_map' key not found in the shard index file.")
+
+ # Collect all unique safetensors files from weight_map
+ files_to_load = set(weight_map.values())
+ is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
+ merged_state_dict = {}
+
+ # Load tensors from each unique file
+ for file_name in files_to_load:
+ part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
+ if not os.path.exists(part_file_path):
+ raise FileNotFoundError(f"Part file {file_name} not found.")
+
+ if is_safetensors:
+ with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
+ for tensor_key in f.keys():
+ if tensor_key in weight_map:
+ merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
+ else:
+ merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
+
+ return merged_state_dict
+
+
def _fetch_index_file_legacy(
is_local,
pretrained_model_name_or_path,
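
For orientation, the shard index consumed by `_merge_sharded_checkpoints` follows the standard Hugging Face layout sketched below: `weight_map` maps every tensor name to the shard file that stores it, and the set of its values is what the helper iterates over. Keys and filenames are illustrative.

sharded_metadata = {
    "metadata": {"total_size": 16_000_000_000},
    "weight_map": {
        "transformer_blocks.0.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
        "transformer_blocks.37.ff.net.2.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
    },
}
files_to_load = set(sharded_metadata["weight_map"].values())
print(sorted(files_to_load))
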
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index ad3433889fca..4a486fd4ce40 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -14,13 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import inspect
import itertools
import json
import os
import re
from collections import OrderedDict
-from functools import partial
+from functools import partial, wraps
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
@@ -31,6 +32,8 @@
from torch import Tensor, nn
from .. import __version__
+from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
+from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
CONFIG_NAME,
FLAX_WEIGHTS_NAME,
@@ -43,6 +46,8 @@
_get_model_file,
deprecate,
is_accelerate_available,
+ is_bitsandbytes_available,
+ is_bitsandbytes_version,
is_torch_version,
logging,
)
@@ -56,6 +61,7 @@
_fetch_index_file,
_fetch_index_file_legacy,
_load_state_dict_into_model,
+ _merge_sharded_checkpoints,
load_model_dict_into_meta,
load_state_dict,
)
@@ -125,6 +131,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_supports_gradient_checkpointing = False
_keys_to_ignore_on_load_unexpected = None
_no_split_modules = None
+ _keep_in_fp32_modules = None
def __init__(self):
super().__init__()
@@ -308,6 +315,19 @@ def save_pretrained(
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
+ hf_quantizer = getattr(self, "hf_quantizer", None)
+ if hf_quantizer is not None:
+ quantization_serializable = (
+ hf_quantizer is not None
+ and isinstance(hf_quantizer, DiffusersQuantizer)
+ and hf_quantizer.is_serializable
+ )
+ if not quantization_serializable:
+ raise ValueError(
+ f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+ " the logger on the traceback to understand the reason why the quantized model is not serializable."
+ )
+
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
@@ -402,6 +422,18 @@ def save_pretrained(
create_pr=create_pr,
)
+ def dequantize(self):
+ """
+ Potentially dequantize the model in case it has been quantized by a quantization method that supports
+ dequantization.
+ """
+ hf_quantizer = getattr(self, "hf_quantizer", None)
+
+ if hf_quantizer is None:
+ raise ValueError("You need to first quantize your model in order to dequantize it")
+
+ return hf_quantizer.dequantize(self)
+
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
@@ -524,6 +556,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
+ quantization_config = kwargs.pop("quantization_config", None)
allow_pickle = False
if use_safetensors is None:
@@ -618,6 +651,60 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
user_agent=user_agent,
**kwargs,
)
+ # no in-place modification of the original config.
+ config = copy.deepcopy(config)
+
+ # determine initial quantization config.
+ #######################################
+ pre_quantized = "quantization_config" in config and config["quantization_config"] is not None
+ if pre_quantized or quantization_config is not None:
+ if pre_quantized:
+ config["quantization_config"] = DiffusersAutoQuantizer.merge_quantization_configs(
+ config["quantization_config"], quantization_config
+ )
+ else:
+ config["quantization_config"] = quantization_config
+ hf_quantizer = DiffusersAutoQuantizer.from_config(
+ config["quantization_config"], pre_quantized=pre_quantized
+ )
+ else:
+ hf_quantizer = None
+
+ if hf_quantizer is not None:
+ if device_map is not None:
+ raise NotImplementedError(
+ "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
+ )
+ hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
+ torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+
+ # Record the quantization method in the user agent in order to ensure popular quantization methods are supported. Can be disabled with `disable_telemetry`.
+ user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+
+ # Force-set to `True` for more mem efficiency
+ if low_cpu_mem_usage is None:
+ low_cpu_mem_usage = True
+ logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.")
+ elif not low_cpu_mem_usage:
+ raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")
+
+ # Check if `_keep_in_fp32_modules` is not None
+ use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+ (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+ )
+ if use_keep_in_fp32_modules:
+ keep_in_fp32_modules = cls._keep_in_fp32_modules
+ if not isinstance(keep_in_fp32_modules, list):
+ keep_in_fp32_modules = [keep_in_fp32_modules]
+
+ if low_cpu_mem_usage is None:
+ low_cpu_mem_usage = True
+ logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
+ elif not low_cpu_mem_usage:
+ raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
+ else:
+ keep_in_fp32_modules = []
+ #######################################
# Determine if we're loading from a directory of sharded checkpoints.
is_sharded = False
@@ -684,6 +771,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision=revision,
subfolder=subfolder or "",
)
+ if hf_quantizer is not None:
+ model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
+ logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
+ is_sharded = False
elif use_safetensors and not is_sharded:
try:
@@ -729,13 +820,30 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
with accelerate.init_empty_weights():
model = cls.from_config(config, **unused_kwargs)
+ if hf_quantizer is not None:
+ hf_quantizer.preprocess_model(
+ model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
+ )
+
# if device_map is None, load the state dict and move the params from meta device to the cpu
if device_map is None and not is_sharded:
- param_device = "cpu"
+ # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
+ # It would error out during the `validate_environment()` call above in the absence of cuda.
+ is_quant_method_bnb = (
+ getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+ )
+ if hf_quantizer is None:
+ param_device = "cpu"
+ # TODO (sayakpaul, SunMarc): remove this after model loading refactor
+ elif is_quant_method_bnb:
+ param_device = torch.cuda.current_device()
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)
+
# move the params from meta device to cpu
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+ if hf_quantizer is not None:
+ missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
if len(missing_keys) > 0:
raise ValueError(
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
@@ -750,6 +858,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
device=param_device,
dtype=torch_dtype,
model_name_or_path=pretrained_model_name_or_path,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
)
if cls._keys_to_ignore_on_load_unexpected is not None:
@@ -765,7 +875,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# Load weights and dispatch according to the device_map
# by default the device_map is None and the weights are loaded on the CPU
force_hook = True
- device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
+ device_map = _determine_device_map(
+ model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
+ )
if device_map is None and is_sharded:
# we load the parameters on the cpu
device_map = {"": "cpu"}
@@ -843,14 +955,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
"error_msgs": error_msgs,
}
+ if hf_quantizer is not None:
+ hf_quantizer.postprocess_model(model)
+ model.hf_quantizer = hf_quantizer
+
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
raise ValueError(
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
)
- elif torch_dtype is not None:
+ # When `use_keep_in_fp32_modules` is set, doing a global `to()` here would
+ # completely undo the effect of `use_keep_in_fp32_modules`.
+ elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:
model = model.to(torch_dtype)
- model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+ if hf_quantizer is not None:
+ # We also make sure to purge `_pre_quantization_dtype` when we serialize
+ # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable.
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path, _pre_quantization_dtype=torch_dtype)
+ else:
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
@@ -859,6 +982,76 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
return model
+ # Adapted from `transformers`.
+ @wraps(torch.nn.Module.cuda)
+ def cuda(self, *args, **kwargs):
+ # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+ if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+ if getattr(self, "is_loaded_in_8bit", False):
+ raise ValueError(
+ "Calling `cuda()` is not supported for `8-bit` quantized models. "
+ " Please use the model as it is, since the model has already been set to the correct devices."
+ )
+ elif is_bitsandbytes_version("<", "0.43.2"):
+ raise ValueError(
+ "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+ f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+ )
+ return super().cuda(*args, **kwargs)
+
+ # Adapted from `transformers`.
+ @wraps(torch.nn.Module.to)
+ def to(self, *args, **kwargs):
+ dtype_present_in_args = "dtype" in kwargs
+
+ if not dtype_present_in_args:
+ for arg in args:
+ if isinstance(arg, torch.dtype):
+ dtype_present_in_args = True
+ break
+
+ # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+ if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+ if dtype_present_in_args:
+ raise ValueError(
+ "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
+ " desired `dtype` by passing the correct `torch_dtype` argument."
+ )
+
+ if getattr(self, "is_loaded_in_8bit", False):
+ raise ValueError(
+ "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+ " model has already been set to the correct devices and casted to the correct `dtype`."
+ )
+ elif is_bitsandbytes_version("<", "0.43.2"):
+ raise ValueError(
+ "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+ f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+ )
+ return super().to(*args, **kwargs)
+
+ # Taken from `transformers`.
+ def half(self, *args):
+ # Checks if the model is quantized
+ if getattr(self, "is_quantized", False):
+ raise ValueError(
+ "`.half()` is not supported for quantized model. Please use the model as it is, since the"
+ " model has already been cast to the correct `dtype`."
+ )
+ else:
+ return super().half(*args)
+
+ # Taken from `transformers`.
+ def float(self, *args):
+ # Checks if the model is quantized
+ if getattr(self, "is_quantized", False):
+ raise ValueError(
+ "`.float()` is not supported for quantized model. Please use the model as it is, since the"
+ " model has already been cast to the correct `dtype`."
+ )
+ else:
+ return super().float(*args)
+
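
A minimal usage sketch of the guards above (illustrative, not part of the diff; it assumes bitsandbytes >= 0.43.2, a CUDA device, and access to the checkpoint named below): device moves are allowed on a quantized model, while any dtype cast raises.

```python
# Sketch only: exercising the `.to()`/`.half()` guards on a 4-bit quantized model.
# Assumes bitsandbytes >= 0.43.2, a CUDA device, and access to the checkpoint below.
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    torch_dtype=torch.float16,
)

model.to("cuda")            # allowed: a pure device move
# model.to(torch.float32)   # raises ValueError: cannot cast a bitsandbytes model to a new dtype
# model.half()              # raises ValueError: not supported for quantized models
```
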
@classmethod
def _load_pretrained_model(
cls,
@@ -1041,19 +1234,63 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool
859520964
```
"""
+ is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
+
+ if is_loaded_in_4bit:
+ if is_bitsandbytes_available():
+ import bitsandbytes as bnb
+ else:
+ raise ValueError(
+ "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong"
+ " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. "
+ )
if exclude_embeddings:
embedding_param_names = [
- f"{name}.weight"
- for name, module_type in self.named_modules()
- if isinstance(module_type, torch.nn.Embedding)
+ f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
]
- non_embedding_parameters = [
+ total_parameters = [
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
]
- return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
else:
- return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+ total_parameters = list(self.parameters())
+
+ total_numel = []
+
+ for param in total_parameters:
+ if param.requires_grad or not only_trainable:
+ # For 4-bit params, bitsandbytes packs two 4-bit values into each stored element, so the
+ # uint8 tensor holds half the logical parameters; multiply by 2 (times the storage
+ # element size in bytes) to recover the true count.
+ if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
+ if hasattr(param, "element_size"):
+ num_bytes = param.element_size()
+ elif hasattr(param, "quant_storage"):
+ num_bytes = param.quant_storage.itemsize
+ else:
+ num_bytes = 1
+ total_numel.append(param.numel() * 2 * num_bytes)
+ else:
+ total_numel.append(param.numel())
+
+ return sum(total_numel)
+
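
To make the 4-bit counting above concrete, a small arithmetic sketch (illustrative numbers only):

```python
# Illustrative arithmetic for the Params4bit branch above: bitsandbytes packs two
# 4-bit values per stored uint8 element, so numel() reports half the logical count.
stored_elements = 500_000   # param.numel() for a layer with 1,000,000 logical weights
num_bytes = 1               # uint8 storage -> element_size() == 1
logical_params = stored_elements * 2 * num_bytes
assert logical_params == 1_000_000
```
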
+ def get_memory_footprint(self, return_buffers=True):
+ r"""
+ Get the memory footprint of a model. This returns the memory footprint of the current model, in bytes. Useful
+ for benchmarking the memory usage of the current model and for designing tests. Solution inspired by this
+ PyTorch discussion: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
+
+ Arguments:
+ return_buffers (`bool`, *optional*, defaults to `True`):
+ Whether to include the size of buffer tensors in the computation of the memory footprint. Buffers
+ are tensors that do not require gradients and are not registered as parameters, e.g. the running mean
+ and variance in batch norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
+ """
+ mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
+ if return_buffers:
+ mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
+ mem = mem + mem_bufs
+ return mem
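
A usage sketch for the new helper (the model id and 4-bit setup are assumptions, not part of the diff):

```python
# Sketch: report the memory footprint of a 4-bit load of the SD3 transformer.
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

model_4bit = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
print(f"{model_4bit.get_memory_footprint() / 1024**3:.2f} GiB")  # parameters + buffers, in bytes
```
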
def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
deprecated_attention_block_paths = []
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index 21e9d3cd6fc5..029c147fcbac 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -97,6 +97,40 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
).to(origin_dtype)
+class SD35AdaLayerNormZeroX(nn.Module):
+ r"""
+ Norm layer adaptive layer norm zero (AdaLN-Zero).
+
+ Parameters:
+ embedding_dim (`int`): The size of each embedding vector.
+ num_embeddings (`int`): The size of the embeddings dictionary.
+ """
+
+ def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None:
+ super().__init__()
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+ else:
+ raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ emb: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, ...]:
+ emb = self.linear(self.silu(emb))
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk(
+ 9, dim=1
+ )
+ norm_hidden_states = self.norm(hidden_states)
+ hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
+ return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2
+
+
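
A quick shape sketch for the new block (assumes this branch is installed; sizes are arbitrary): the conditioning embedding is projected to nine chunks, yielding two modulated hidden-state streams plus their gates.

```python
# Shape sketch for SD35AdaLayerNormZeroX (illustrative sizes only).
import torch
from diffusers.models.normalization import SD35AdaLayerNormZeroX

norm = SD35AdaLayerNormZeroX(embedding_dim=64)
hidden_states = torch.randn(2, 16, 64)  # (batch, tokens, dim)
emb = torch.randn(2, 64)                # conditioning embedding

out = norm(hidden_states, emb)
# out[0] and out[5] are the two modulated streams, each (2, 16, 64);
# the remaining five entries are per-sample gates/shifts/scales of shape (2, 64).
print([tuple(t.shape) for t in out])
```
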
class AdaLayerNormZero(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index 9376c91d0756..b28350b8ed9c 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -13,7 +13,7 @@
# limitations under the License.
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -69,6 +69,10 @@ def __init__(
pooled_projection_dim: int = 2048,
out_channels: int = 16,
pos_embed_max_size: int = 96,
+ dual_attention_layers: Tuple[
+ int, ...
+ ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+ qk_norm: Optional[str] = None,
):
super().__init__()
default_out_channels = in_channels
@@ -97,6 +101,8 @@ def __init__(
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
context_pre_only=i == num_layers - 1,
+ qk_norm=qk_norm,
+ use_dual_attention=(i in dual_attention_layers),
)
for i in range(self.config.num_layers)
]
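
A hedged sketch of how the new arguments can be combined when instantiating a small transformer (the `qk_norm` value and the layer indices are assumptions taken from the inline comment above, not from a released checkpoint config):

```python
# Illustrative only (assumes this branch): a tiny SD3 transformer with the new
# dual-attention options enabled on the first block.
from diffusers import SD3Transformer2DModel

transformer = SD3Transformer2DModel(
    sample_size=32,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,
    pooled_projection_dim=64,
    dual_attention_layers=(0,),  # only the first block uses the dual-attention path
    qk_norm="rms_norm",          # assumption: SD3.5-style RMS-normed Q/K
)
print(sum(p.numel() for p in transformer.parameters()))
```
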
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index d7ff34310beb..7366520f4692 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -144,6 +144,7 @@
"CogVideoXPipeline",
"CogVideoXImageToVideoPipeline",
"CogVideoXVideoToVideoPipeline",
+ "CogVideoXFunControlPipeline",
]
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
_import_structure["controlnet"].extend(
@@ -470,7 +471,12 @@
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
- from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline
+ from .cogvideo import (
+ CogVideoXFunControlPipeline,
+ CogVideoXImageToVideoPipeline,
+ CogVideoXPipeline,
+ CogVideoXVideoToVideoPipeline,
+ )
from .cogview3 import CogView3PlusPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
index 6fe5c5604c86..181448fc2f5e 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
@@ -113,9 +113,21 @@
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
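
For readers of the expanded docstring, a minimal standalone sketch of the rescaling it describes (it mirrors the standard diffusers helper; the function body continues unchanged past this hunk):

```python
import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # match the standard deviation of the guided prediction to that of the text prediction (Section 3.4)
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # interpolate with the original prediction to avoid overly plain-looking images
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
```
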
@@ -135,7 +147,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
index de5a16b05b40..c7ff97ce4226 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
@@ -119,7 +119,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
index 1d26f95a2f58..9a93f1d28d35 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
@@ -131,7 +131,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
index 1e3ba67adf5d..ae55b3cab17e 100644
--- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
+++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -54,7 +54,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/cogvideo/__init__.py b/src/diffusers/pipelines/cogvideo/__init__.py
index bd60fcea9994..e4fa1dda53d3 100644
--- a/src/diffusers/pipelines/cogvideo/__init__.py
+++ b/src/diffusers/pipelines/cogvideo/__init__.py
@@ -23,6 +23,7 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
+ _import_structure["pipeline_cogvideox_fun_control"] = ["CogVideoXFunControlPipeline"]
_import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"]
_import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"]
@@ -35,6 +36,7 @@
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_cogvideox import CogVideoXPipeline
+ from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline
from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
index fdafdbc7c019..44ed94333e4a 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -86,7 +86,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
@@ -521,7 +521,7 @@ def __call__(
num_frames (`int`, defaults to `48`):
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
- num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
+ num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
needs to be satisfied is that of divisibility mentioned above.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
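
A worked example for the corrected docstring numbers (default CogVideoX setup assumed: 6 seconds at 8 fps, temporal VAE compression of 4):

```python
# 6 seconds at 8 fps, plus the single conditioning frame, gives 49 frames;
# with a temporal compression of 4 that corresponds to 13 latent frames.
num_seconds, fps, temporal_compression = 6, 8, 4
num_frames = num_seconds * fps + 1                             # 49
latent_frames = (num_frames - 1) // temporal_compression + 1   # 13
print(num_frames, latent_frames)
```
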
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
new file mode 100644
index 000000000000..3655075bd519
--- /dev/null
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
@@ -0,0 +1,794 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI, Alibaba-PAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from PIL import Image
+from transformers import T5EncoderModel, T5Tokenizer
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import CogVideoXLoraLoaderMixin
+from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
+from ...models.embeddings import get_3d_rotary_pos_embed
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from .pipeline_output import CogVideoXPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import torch
+ >>> from diffusers import CogVideoXFunControlPipeline, DDIMScheduler
+ >>> from diffusers.utils import export_to_video, load_video
+
+ >>> pipe = CogVideoXFunControlPipeline.from_pretrained(
+ ... "alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.to("cuda")
+
+ >>> control_video = load_video(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
+ ... )
+ >>> prompt = (
+ ... "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and "
+ ... "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in "
+ ... "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, "
+ ... "moons, but the remainder of the scene is mostly realistic."
+ ... )
+
+ >>> video = pipe(prompt=prompt, control_video=control_video).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=8)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
+ r"""
+ Pipeline for controlled text-to-video generation using CogVideoX Fun.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKLCogVideoX`]):
+ Variational Auto-Encoder (VAE) model to encode and decode videos to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. CogVideoX uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`CogVideoXTransformer3DModel`]):
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ _optional_components = []
+ model_cpu_offload_seq = "text_encoder->vae->transformer->vae"
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKLCogVideoX,
+ transformer: CogVideoXTransformer3DModel,
+ scheduler: KarrasDiffusionSchedulers,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+ self.vae_scale_factor_spatial = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
+ )
+ self.vae_scaling_factor_image = (
+ self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.prepare_latents
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ shape = (
+ batch_size,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ num_channels_latents,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Adapted from https://github.com/aigc-apps/CogVideoX-Fun/blob/2a93e5c14e02b2b5921d533fd59fc8c0ed69fb24/cogvideox/pipeline/pipeline_cogvideox_control.py#L366
+ def prepare_control_latents(
+ self, mask: Optional[torch.Tensor] = None, masked_image: Optional[torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if mask is not None:
+ masks = []
+ for i in range(mask.size(0)):
+ current_mask = mask[i].unsqueeze(0)
+ current_mask = self.vae.encode(current_mask)[0]
+ current_mask = current_mask.mode()
+ masks.append(current_mask)
+ mask = torch.cat(masks, dim=0)
+ mask = mask * self.vae.config.scaling_factor
+
+ if masked_image is not None:
+ mask_pixel_values = []
+ for i in range(masked_image.size(0)):
+ mask_pixel_value = masked_image[i].unsqueeze(0)
+ mask_pixel_value = self.vae.encode(mask_pixel_value)[0]
+ mask_pixel_value = mask_pixel_value.mode()
+ mask_pixel_values.append(mask_pixel_value)
+ masked_image_latents = torch.cat(mask_pixel_values, dim=0)
+ masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
+ else:
+ masked_image_latents = None
+
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
+ latents = 1 / self.vae_scaling_factor_image * latents
+
+ frames = self.vae.decode(latents).sample
+ return frames
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ control_video=None,
+ control_video_latents=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if control_video is not None and control_video_latents is not None:
+ raise ValueError(
+ "Cannot pass both `control_video` and `control_video_latents`. Please make sure to pass only one of these parameters."
+ )
+
+ def fuse_qkv_projections(self) -> None:
+ r"""Enables fused QKV projections."""
+ self.fusing_transformer = True
+ self.transformer.fuse_qkv_projections()
+
+ def unfuse_qkv_projections(self) -> None:
+ r"""Disable QKV projection fusion if enabled."""
+ if not self.fusing_transformer:
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.transformer.unfuse_qkv_projections()
+ self.fusing_transformer = False
+
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
+ def _prepare_rotary_positional_embeddings(
+ self,
+ height: int,
+ width: int,
+ num_frames: int,
+ device: torch.device,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+ grid_crops_coords = get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ )
+
+ freqs_cos = freqs_cos.to(device=device)
+ freqs_sin = freqs_sin.to(device=device)
+ return freqs_cos, freqs_sin
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ control_video: Optional[List[Image.Image]] = None,
+ height: int = 480,
+ width: int = 720,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 6,
+ use_dynamic_cfg: bool = False,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ control_video_latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 226,
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ control_video (`List[PIL.Image.Image]`):
+ The control video to condition the generation on. Must be a list of images/frames of the video. If not
+ provided, `control_video_latents` must be provided.
+ height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
+ width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):
+ The width in pixels of the generated image. This is set to 720 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 6.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ control_video_latents (`torch.Tensor`, *optional*):
+ Pre-generated control latents (for example, obtained by encoding the control video with the VAE) to be
+ used as inputs for controlled video generation. If not provided, `control_video` must be provided.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] instead
+ of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `226`):
+ Maximum sequence length of the encoded prompt. Must be consistent with
+ `self.transformer.config.max_text_seq_length`; otherwise, results may be poor.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ control_video,
+ control_video_latents,
+ )
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if control_video is not None and isinstance(control_video[0], Image.Image):
+ control_video = [control_video]
+
+ device = self._execution_device
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels // 2
+ num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2)
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ if control_video_latents is None:
+ control_video = self.video_processor.preprocess_video(control_video, height=height, width=width)
+ control_video = control_video.to(device=device, dtype=prompt_embeds.dtype)
+
+ _, control_video_latents = self.prepare_control_latents(None, control_video)
+ control_video_latents = control_video_latents.permute(0, 2, 1, 3, 4)
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Create rotary embeds if required
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # for DPM-solver++
+ old_pred_original_sample = None
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ latent_control_input = (
+ torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents
+ )
+ latent_model_input = torch.cat([latent_model_input, latent_control_input], dim=2)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ latents = latents.to(prompt_embeds.dtype)
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if not output_type == "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return CogVideoXPipelineOutput(frames=video)
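
A closing shape sketch for the control conditioning in the denoising loop above (values assume the default 480x720, 49-frame setup with 16 latent channels; actual counts depend on the checkpoint):

```python
import torch

# Noisy latents and VAE-encoded control latents share the layout (B, F_latent, C, H/8, W/8).
latents = torch.randn(1, 13, 16, 60, 90)
control_video_latents = torch.randn(1, 13, 16, 60, 90)

# Concatenating along the channel dimension doubles C, which is why the pipeline uses
# latent_channels = transformer.config.in_channels // 2 when preparing the noise.
latent_model_input = torch.cat([latents, control_video_latents], dim=2)
print(latent_model_input.shape)  # torch.Size([1, 13, 32, 60, 90])
```
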
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
index 975e8ed27db8..783dae569bec 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -88,7 +88,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
@@ -604,7 +604,7 @@ def __call__(
num_frames (`int`, defaults to `48`):
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
- num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
+ num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
needs to be satisfied is that of divisibility mentioned above.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index 35f8f2fa0508..e1e816eca16d 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -94,7 +94,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index 7ae86421c45e..64fff61d2c32 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -57,7 +57,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index b06ba4e7cba1..4e80aa5a4b6a 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -101,7 +101,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index fa753bbfa98b..beb7d1c15d94 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -137,9 +137,21 @@ def retrieve_latents(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 67a96493a2eb..df652239df63 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -122,7 +122,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
@@ -242,6 +242,7 @@ class StableDiffusionXLControlNetPipeline(
"add_time_ids",
"negative_pooled_prompt_embeds",
"negative_add_time_ids",
+ "image",
]
def __init__(
@@ -1541,6 +1542,7 @@ def __call__(
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
+ image = callback_outputs.pop("image", image)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
diff --git a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
index 8e738c456802..0be580eb7f8a 100644
--- a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
@@ -141,9 +141,21 @@ def get_resize_crop_region_for_grid(src, tgt_size):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index 02b4e46ab182..1764675df70b 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -83,7 +83,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
index 3b099721eb2d..170e1fc54deb 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
@@ -108,7 +108,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
index 0d4ca3799a53..1f7f4db728ad 100644
--- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
@@ -65,9 +65,21 @@
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -87,7 +99,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
index 54823937f58e..b3635be77ab4 100644
--- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -127,7 +127,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index f2040b184eac..e8f01273828a 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -20,7 +20,7 @@
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
+from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -86,7 +86,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
@@ -137,7 +137,12 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
+class FluxPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+):
r"""
The Flux pipeline for text-to-image generation.
@@ -212,6 +217,9 @@ def _get_t5_prompt_embeds(
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
@@ -255,6 +263,9 @@ def _get_clip_prompt_embeds(
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
index cb47b9be95d0..efefb80bc9dc 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -25,7 +25,7 @@
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
+from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
@@ -106,7 +106,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
@@ -238,6 +238,9 @@ def _get_t5_prompt_embeds(
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
@@ -281,6 +284,9 @@ def _get_clip_prompt_embeds(
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
index a7f7c66a2cad..8d636feeae05 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
@@ -11,7 +11,7 @@
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
+from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
@@ -118,7 +118,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
@@ -251,6 +251,9 @@ def _get_t5_prompt_embeds(
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
@@ -295,6 +298,9 @@ def _get_clip_prompt_embeds(
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -897,9 +903,12 @@ def __call__(
timestep = t.expand(latents.shape[0]).to(latents.dtype)
- guidance = (
- torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
- )
+ if isinstance(self.controlnet, FluxMultiControlNetModel):
+ use_guidance = self.controlnet.nets[0].config.guidance_embeds
+ else:
+ use_guidance = self.controlnet.config.guidance_embeds
+
+ guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
if isinstance(controlnet_keep[i], list):
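The hunk above fixes the guidance lookup when `controlnet` is a `FluxMultiControlNetModel`: the `guidance_embeds` flag lives on the wrapped nets, not on the wrapper. A standalone sketch of the same decision; the helper name is illustrative and not part of the diff:

```python
import torch

def resolve_guidance(controlnet, guidance_scale, batch_size, device):
    # A FluxMultiControlNetModel keeps its sub-models in `nets`, so read the config
    # of the first one. A plain FluxControlNetModel exposes the config directly.
    if hasattr(controlnet, "nets"):  # stands in for isinstance(..., FluxMultiControlNetModel)
        use_guidance = controlnet.nets[0].config.guidance_embeds
    else:
        use_guidance = controlnet.config.guidance_embeds

    if not use_guidance:
        return None
    # Same shape handling as the pipeline: a scalar expanded to the batch size.
    return torch.tensor([guidance_scale], device=device).expand(batch_size)
```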
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index 50d2fcaa7fa5..46784f2d46d1 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -12,7 +12,7 @@
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
+from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
@@ -120,7 +120,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
@@ -261,6 +261,9 @@ def _get_t5_prompt_embeds(
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
@@ -305,6 +308,9 @@ def _get_clip_prompt_embeds(
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
index 6375b1a994a1..14b062180003 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
@@ -20,7 +20,7 @@
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin
+from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -108,7 +108,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
@@ -235,6 +235,9 @@ def _get_t5_prompt_embeds(
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
@@ -279,6 +282,9 @@ def _get_clip_prompt_embeds(
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
index b9114c923cba..2f0be67bf88c 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
@@ -21,7 +21,7 @@
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin
+from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -105,7 +105,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
@@ -239,6 +239,9 @@ def _get_t5_prompt_embeds(
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
@@ -283,6 +286,9 @@ def _get_clip_prompt_embeds(
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
text_inputs = self.tokenizer(
prompt,
padding="max_length",
diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
index 40f20fe55bfb..304d6ed2a6c9 100644
--- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
+++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
@@ -125,9 +125,21 @@ def get_resize_crop_region_for_grid(src, tgt_size):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors.py b/src/diffusers/pipelines/kolors/pipeline_kolors.py
index 6f25fd1ac0be..a55d6b97f151 100644
--- a/src/diffusers/pipelines/kolors/pipeline_kolors.py
+++ b/src/diffusers/pipelines/kolors/pipeline_kolors.py
@@ -70,7 +70,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
index 145daf0181fa..6433b2f1f7f1 100644
--- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
+++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
@@ -89,7 +89,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
index dd72d3c9e10e..e985648abace 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -66,7 +66,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
index 89cafc2877fe..d110cd464522 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
@@ -70,7 +70,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py
index 8b78a7e75681..b0aa298b34ef 100644
--- a/src/diffusers/pipelines/latte/pipeline_latte.py
+++ b/src/diffusers/pipelines/latte/pipeline_latte.py
@@ -76,7 +76,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
index 537b509e3720..83b3753c2c87 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
@@ -234,9 +234,21 @@ def __call__(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
index 995cc15f3f93..834445bfcd06 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
@@ -1643,9 +1643,21 @@ def invert(
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py
index 1ded6013fb35..9a4b37972379 100644
--- a/src/diffusers/pipelines/lumina/pipeline_lumina.py
+++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py
@@ -73,7 +73,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
index 92a648a0d194..f3e4e94c836d 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
@@ -104,7 +104,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
index debcdad85656..6ccf11a3b774 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
@@ -126,7 +126,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
index ebcff0ea269a..7a88451273d4 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
@@ -128,9 +128,21 @@ def get_resize_crop_region_for_grid(src, tgt_size):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
index 4dc01b7caa64..7c9b6e1e065d 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
@@ -75,7 +75,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
index 8e5e6cbaf5ad..59d6a9001e1f 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
@@ -81,7 +81,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_sd.py
index 6a3efb71f8e0..0b0bd5e71aa3 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd.py
@@ -60,9 +60,21 @@
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -82,7 +94,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
index 2f0d997359a1..e5bb01dd2409 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
@@ -82,7 +82,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
index e4f26494d5c3..49dc4948cb40 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
@@ -89,7 +89,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
index 904942528a36..d850df080fb5 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
@@ -88,9 +88,21 @@
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -110,7 +122,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
index 2b17537dcfc6..74b7ede6a8d6 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
@@ -92,9 +92,21 @@
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -128,7 +140,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
index 7c2ec2e726ad..e36c1dc77396 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
@@ -105,9 +105,21 @@
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -141,7 +153,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 0a744264b7a6..5eba1952e608 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -118,6 +118,10 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
components.setdefault(component, [])
components[component].append(component_filename)
+ # If there are no component folders, check the main directory for safetensors files
+ if not components:
+ return any(".safetensors" in filename for filename in filenames)
+
# iterate over all files of a component
# check if safetensor files exist for that component
# if variant is provided check if the variant of the safetensors exists
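The new early return covers repos that keep their weights at the top level instead of in per-component folders: if no component mapping was built, any top-level `.safetensors` file is enough to treat the repo as safetensors-compatible. A toy illustration of just that branch, with invented file listings:

```python
def top_level_safetensors_fallback(filenames):
    # Mirrors the added branch: no component folders were collected, so look for
    # any safetensors file directly in the listing.
    return any(".safetensors" in filename for filename in filenames)

print(top_level_safetensors_fallback(["model.safetensors", "config.json"]))  # True
print(top_level_safetensors_fallback(["model.bin", "config.json"]))          # False
```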
@@ -838,3 +842,108 @@ def get_connected_passed_kwargs(prefix):
)
return init_kwargs
+
+
+def _get_custom_components_and_folders(
+ pretrained_model_name: str,
+ config_dict: Dict[str, Any],
+ filenames: Optional[List[str]] = None,
+ variant_filenames: Optional[List[str]] = None,
+ variant: Optional[str] = None,
+):
+ config_dict = config_dict.copy()
+
+ # retrieve all folder_names that contain relevant files
+ folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
+
+ diffusers_module = importlib.import_module(__name__.split(".")[0])
+ pipelines = getattr(diffusers_module, "pipelines")
+
+ # optionally create a custom component <> custom file mapping
+ custom_components = {}
+ for component in folder_names:
+ module_candidate = config_dict[component][0]
+
+ if module_candidate is None or not isinstance(module_candidate, str):
+ continue
+
+ # We compute candidate file path on the Hub. Do not use `os.path.join`.
+ candidate_file = f"{component}/{module_candidate}.py"
+
+ if candidate_file in filenames:
+ custom_components[component] = module_candidate
+ elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
+ raise ValueError(
+ f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
+ )
+
+ if len(variant_filenames) == 0 and variant is not None:
+ error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
+ raise ValueError(error_message)
+
+ return custom_components, folder_names
+
+
+def _get_ignore_patterns(
+ passed_components,
+ model_folder_names: List[str],
+ model_filenames: List[str],
+ variant_filenames: List[str],
+ use_safetensors: bool,
+ from_flax: bool,
+ allow_pickle: bool,
+ use_onnx: bool,
+ is_onnx: bool,
+ variant: Optional[str] = None,
+) -> List[str]:
+ if (
+ use_safetensors
+ and not allow_pickle
+ and not is_safetensors_compatible(
+ model_filenames, passed_components=passed_components, folder_names=model_folder_names
+ )
+ ):
+ raise EnvironmentError(
+ f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
+ )
+
+ if from_flax:
+ ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
+
+ elif use_safetensors and is_safetensors_compatible(
+ model_filenames, passed_components=passed_components, folder_names=model_folder_names
+ ):
+ ignore_patterns = ["*.bin", "*.msgpack"]
+
+ use_onnx = use_onnx if use_onnx is not None else is_onnx
+ if not use_onnx:
+ ignore_patterns += ["*.onnx", "*.pb"]
+
+ safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
+ safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
+ if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames:
+ logger.warning(
+ f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
+ f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
+ f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not "
+ f"expected, please check your folder structure."
+ )
+
+ else:
+ ignore_patterns = ["*.safetensors", "*.msgpack"]
+
+ use_onnx = use_onnx if use_onnx is not None else is_onnx
+ if not use_onnx:
+ ignore_patterns += ["*.onnx", "*.pb"]
+
+ bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
+ bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
+ if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
+ logger.warning(
+ f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n"
+ f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n"
+ f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check "
+ f"your folder structure."
+ )
+
+ return ignore_patterns
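Pulling this logic into `_get_ignore_patterns` lets `DiffusionPipeline.download` (see the `pipeline_utils.py` hunk further down) decide which weight formats to skip in a single call. A hedged example of how the helper might be invoked; the filename lists are invented, the function is internal, and the expected output is indicative only:

```python
# Internal helper; the import path follows this file's location and may change.
from diffusers.pipelines.pipeline_loading_utils import _get_ignore_patterns

ignore_patterns = _get_ignore_patterns(
    passed_components=[],
    model_folder_names=["unet", "vae"],
    model_filenames=[
        "unet/diffusion_pytorch_model.safetensors",
        "vae/diffusion_pytorch_model.safetensors",
    ],
    variant_filenames=[],
    use_safetensors=True,
    from_flax=False,
    allow_pickle=False,
    use_onnx=False,
    is_onnx=False,
    variant=None,
)
# With safetensors available and ONNX disabled, this should come back as
# ["*.bin", "*.msgpack", "*.onnx", "*.pb"], so only safetensors weights are fetched.
```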
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 857a13147cfe..2e1858b16148 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -44,6 +44,7 @@
from ..models import AutoencoderKL
from ..models.attention_processor import FusedAttnProcessor2_0
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
+from ..quantizers.bitsandbytes.utils import _check_bnb_status
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
CONFIG_NAME,
@@ -54,6 +55,7 @@
is_accelerate_version,
is_torch_npu_available,
is_torch_version,
+ is_transformers_version,
logging,
numpy_to_pil,
)
@@ -71,15 +73,16 @@
CUSTOM_PIPELINE_FILE_NAME,
LOADABLE_CLASSES,
_fetch_class_library_tuple,
+ _get_custom_components_and_folders,
_get_custom_pipeline_class,
_get_final_device_map,
+ _get_ignore_patterns,
_get_pipeline_class,
_identify_model_variants,
_maybe_raise_warning_for_inpainting,
_resolve_custom_pipeline_and_cls,
_unwrap_model,
_update_init_kwargs_with_connected_pipeline,
- is_safetensors_compatible,
load_sub_model,
maybe_raise_or_warn,
variant_compatible_siblings,
@@ -431,18 +434,23 @@ def module_is_offloaded(module):
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
for module in modules:
- is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit
+ _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
- if is_loaded_in_8bit and dtype is not None:
+ if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
logger.warning(
- f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
+ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
)
- if is_loaded_in_8bit and device is not None:
+ if is_loaded_in_8bit_bnb and device is not None:
logger.warning(
- f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}."
+ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
)
- else:
+
+ # Moving 4-bit bnb `transformers` models across devices was only added in
+ # https://github.com/huggingface/transformers/pull/33122, so we guard on the transformers version.
+ if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
+ module.to(device=device)
+ elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
module.to(device, dtype)
if (
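The `_check_bnb_status` guard above changes what `DiffusionPipeline.to` does with bitsandbytes-quantized modules: dtype casts are skipped for both 4-bit and 8-bit modules, device moves are skipped for 8-bit, and 4-bit modules are only moved on sufficiently recent transformers. A rough sketch of the resulting behavior; the pipeline id is hypothetical and assumes a text encoder loaded in 8-bit:

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers.bitsandbytes.utils import _check_bnb_status

# Hypothetical repo whose text encoder was quantized with bitsandbytes 8-bit.
pipe = DiffusionPipeline.from_pretrained("some-org/some-8bit-pipeline")

_, is_4bit, is_8bit = _check_bnb_status(pipe.text_encoder)
print(is_4bit, is_8bit)

pipe.to(torch.float16)  # quantized modules keep their dtype; only a warning is logged
pipe.to("cuda")         # 8-bit modules stay put; 4-bit ones move only on transformers > 4.44.0
```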
@@ -1039,9 +1047,18 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
hook = None
for model_str in self.model_cpu_offload_seq.split("->"):
model = all_model_components.pop(model_str, None)
+
if not isinstance(model, torch.nn.Module):
continue
+ # Skip bnb 8-bit models: they are already placed on an accelerator and cannot be offloaded.
+ _, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
+ if is_loaded_in_8bit_bnb:
+ logger.info(
+ f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
+ )
+ continue
+
_, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
self._all_hooks.append(hook)
@@ -1298,44 +1315,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
config_dict = cls._dict_from_json_file(config_file)
ignore_filenames = config_dict.pop("_ignore_files", [])
- # retrieve all folder_names that contain relevant files
- folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
-
- diffusers_module = importlib.import_module(__name__.split(".")[0])
- pipelines = getattr(diffusers_module, "pipelines")
-
- # optionally create a custom component <> custom file mapping
- custom_components = {}
- for component in folder_names:
- module_candidate = config_dict[component][0]
-
- if module_candidate is None or not isinstance(module_candidate, str):
- continue
-
- # We compute candidate file path on the Hub. Do not use `os.path.join`.
- candidate_file = f"{component}/{module_candidate}.py"
-
- if candidate_file in filenames:
- custom_components[component] = module_candidate
- elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
- raise ValueError(
- f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
- )
-
- if len(variant_filenames) == 0 and variant is not None:
- error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
- raise ValueError(error_message)
-
# remove ignored filenames
model_filenames = set(model_filenames) - set(ignore_filenames)
variant_filenames = set(variant_filenames) - set(ignore_filenames)
- # if the whole pipeline is cached we don't have to ping the Hub
if revision in DEPRECATED_REVISION_ARGS and version.parse(
version.parse(__version__).base_version
) >= version.parse("0.22.0"):
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
+ custom_components, folder_names = _get_custom_components_and_folders(
+ pretrained_model_name, config_dict, filenames, variant_filenames, variant
+ )
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
custom_class_name = None
@@ -1395,49 +1386,19 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
expected_components, _ = cls._get_signature_keys(pipeline_class)
passed_components = [k for k in expected_components if k in kwargs]
- if (
- use_safetensors
- and not allow_pickle
- and not is_safetensors_compatible(
- model_filenames, passed_components=passed_components, folder_names=model_folder_names
- )
- ):
- raise EnvironmentError(
- f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
- )
- if from_flax:
- ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
- elif use_safetensors and is_safetensors_compatible(
- model_filenames, passed_components=passed_components, folder_names=model_folder_names
- ):
- ignore_patterns = ["*.bin", "*.msgpack"]
-
- use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
- if not use_onnx:
- ignore_patterns += ["*.onnx", "*.pb"]
-
- safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
- safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
- if (
- len(safetensors_variant_filenames) > 0
- and safetensors_model_filenames != safetensors_variant_filenames
- ):
- logger.warning(
- f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
- )
- else:
- ignore_patterns = ["*.safetensors", "*.msgpack"]
-
- use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
- if not use_onnx:
- ignore_patterns += ["*.onnx", "*.pb"]
-
- bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
- bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
- if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
- logger.warning(
- f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
- )
+ # retrieve all patterns that should not be downloaded and error out when needed
+ ignore_patterns = _get_ignore_patterns(
+ passed_components,
+ model_folder_names,
+ model_filenames,
+ variant_filenames,
+ use_safetensors,
+ from_flax,
+ allow_pickle,
+ use_onnx,
+ pipeline_class._is_onnx,
+ variant,
+ )
# Don't download any objects that are passed
allow_patterns = [
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index 5b220df8058b..46d8ad5e6dfa 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -178,7 +178,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index 69f028914774..b2772d552514 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -122,7 +122,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 5b5ff23df18e..092785017965 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -69,9 +69,21 @@
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -90,7 +102,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
@@ -149,7 +161,7 @@ class StableDiffusionPipeline(
IPAdapterMixin,
FromSingleFileMixin,
):
- r"""
+ """
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
@@ -1066,7 +1078,6 @@ def __call__(
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
-
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload all models
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 0b368ba2ac13..9c90b910e007 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -119,7 +119,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 21c6642b5e9f..4cfe95d0d632 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -60,7 +60,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index f76230b1eeac..e58733ad5397 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -77,7 +77,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index 6a82ed9f48f5..cef9804e821d 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -98,7 +98,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
index 440b6529c9ca..7401be39d6f9 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -97,7 +97,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
index 95532efc6d2c..50d25c08b2f5 100644
--- a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
+++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
@@ -61,9 +61,21 @@
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -83,7 +95,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
index f6fabeeca011..9337c874faeb 100644
--- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
@@ -61,9 +61,21 @@
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -83,7 +95,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index c99ce7ef03eb..3f8c7785957e 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -87,9 +87,21 @@
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -109,7 +121,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index ddb31f28c472..8b53841cb138 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -90,9 +90,21 @@
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -126,7 +138,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 199f5cabaa32..6a5fb284c39d 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -101,9 +101,21 @@
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -153,7 +165,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
index dc283108d143..fe3c9afbba1d 100644
--- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
+++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
@@ -71,7 +71,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
index 3cb7c26bb6a2..1a938aaf9423 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -127,7 +127,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index 0ea197e42e62..20569d0adb32 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -119,9 +119,21 @@ def _preprocess_adapter_image(image, height, width):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -141,7 +153,7 @@ def retrieve_timesteps(
sigmas: Optional[List[float]] = None,
**kwargs,
):
- """
+ r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
index c46a5ce7c084..9ff473cc3a38 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
@@ -310,9 +310,21 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
- """
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py
new file mode 100644
index 000000000000..4c8483a3d6ee
--- /dev/null
+++ b/src/diffusers/quantizers/__init__.py
@@ -0,0 +1,16 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .auto import DiffusersAutoQuantizer
+from .base import DiffusersQuantizer
diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py
new file mode 100644
index 000000000000..97cbcdc0e53f
--- /dev/null
+++ b/src/diffusers/quantizers/auto.py
@@ -0,0 +1,126 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Adapted from
+https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py
+"""
+import warnings
+from typing import Dict, Optional, Union
+
+from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
+from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod
+
+
+AUTO_QUANTIZER_MAPPING = {
+ "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
+ "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
+}
+
+AUTO_QUANTIZATION_CONFIG_MAPPING = {
+ "bitsandbytes_4bit": BitsAndBytesConfig,
+ "bitsandbytes_8bit": BitsAndBytesConfig,
+}
+
+
+class DiffusersAutoQuantizer:
+ """
+ The auto diffusers quantizer class that takes care of automatically instantiating the correct
+ `DiffusersQuantizer` for a given `QuantizationConfig`.
+ """
+
+ @classmethod
+ def from_dict(cls, quantization_config_dict: Dict):
+ quant_method = quantization_config_dict.get("quant_method", None)
+ # We need special care for bnb models to make sure everything stays backward compatible.
+ if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
+ suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
+ quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
+ elif quant_method is None:
+ raise ValueError(
+ "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
+ )
+
+ if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
+ raise ValueError(
+ f"Unknown quantization type, got {quant_method} - supported types are:"
+ f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
+ )
+
+ target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
+ return target_cls.from_dict(quantization_config_dict)
+
+ @classmethod
+ def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
+ # Convert it to a QuantizationConfig if the q_config is a dict
+ if isinstance(quantization_config, dict):
+ quantization_config = cls.from_dict(quantization_config)
+
+ quant_method = quantization_config.quant_method
+
+ # Again, we need special care for bnb since we have a single quantization config
+ # class for both 4-bit and 8-bit quantization.
+ if quant_method == QuantizationMethod.BITS_AND_BYTES:
+ if quantization_config.load_in_8bit:
+ quant_method += "_8bit"
+ else:
+ quant_method += "_4bit"
+
+ if quant_method not in AUTO_QUANTIZER_MAPPING.keys():
+ raise ValueError(
+ f"Unknown quantization type, got {quant_method} - supported types are:"
+ f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
+ )
+
+ target_cls = AUTO_QUANTIZER_MAPPING[quant_method]
+ return target_cls(quantization_config, **kwargs)
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+ model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
+ if getattr(model_config, "quantization_config", None) is None:
+ raise ValueError(
+ f"Did not find a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
+ )
+ quantization_config_dict = model_config.quantization_config
+ quantization_config = cls.from_dict(quantization_config_dict)
+ # Update with potential kwargs that are passed through from_pretrained.
+ quantization_config.update(kwargs)
+
+ return cls.from_config(quantization_config)
+
+ @classmethod
+ def merge_quantization_configs(
+ cls,
+ quantization_config: Union[dict, QuantizationConfigMixin],
+ quantization_config_from_args: Optional[QuantizationConfigMixin],
+ ):
+ """
+ Handles situations where both a quantization_config from the arguments and a quantization_config from the
+ model config are present.
+ """
+ if quantization_config_from_args is not None:
+ warning_msg = (
+ "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
+ " already has a `quantization_config` attribute. The `quantization_config` from the model will be used."
+ )
+ else:
+ warning_msg = ""
+
+ if isinstance(quantization_config, dict):
+ quantization_config = cls.from_dict(quantization_config)
+
+ if warning_msg != "":
+ warnings.warn(warning_msg)
+
+ return quantization_config
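As a quick orientation for reviewers, a minimal usage sketch of the dispatch above (illustrative only; it assumes `BitsAndBytesConfig` accepts `load_in_4bit=True` like its `transformers` counterpart):

from diffusers.quantizers import DiffusersAutoQuantizer
from diffusers.quantizers.quantization_config import BitsAndBytesConfig

# `from_config` maps the config to `AUTO_QUANTIZER_MAPPING["bitsandbytes_4bit"]`.
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
quantizer = DiffusersAutoQuantizer.from_config(bnb_config)  # -> BnB4BitDiffusersQuantizer

# Fails early if CUDA, accelerate>=0.26.0, or bitsandbytes>=0.43.3 is missing.
quantizer.validate_environment()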
diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py
new file mode 100644
index 000000000000..6ec3885fe373
--- /dev/null
+++ b/src/diffusers/quantizers/base.py
@@ -0,0 +1,233 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adapted from
+https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/quantizers/base.py
+"""
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from ..utils import is_torch_available
+from .quantization_config import QuantizationConfigMixin
+
+
+if TYPE_CHECKING:
+ from ..models.modeling_utils import ModelMixin
+
+if is_torch_available():
+ import torch
+
+
+class DiffusersQuantizer(ABC):
+ """
+ Abstract base class for HuggingFace quantizers. For now it supports quantizing HF diffusers models for
+ inference. This class is used only by `diffusers.models.modeling_utils.ModelMixin.from_pretrained` and cannot be
+ easily used outside the scope of that method yet.
+
+ Attributes:
+ quantization_config (`diffusers.quantizers.quantization_config.QuantizationConfigMixin`):
+ The quantization config that defines the quantization parameters of your model that you want to quantize.
+ modules_to_not_convert (`List[str]`, *optional*):
+ The list of module names to not convert when quantizing the model.
+ required_packages (`List[str]`, *optional*):
+ The list of required pip packages to install prior to using the quantizer.
+ requires_calibration (`bool`):
+ Whether the quantization method requires calibrating the model before using it.
+ """
+
+ requires_calibration = False
+ required_packages = None
+
+ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
+ self.quantization_config = quantization_config
+
+ # -- Handle extra kwargs below --
+ self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", [])
+ self.pre_quantized = kwargs.pop("pre_quantized", True)
+
+ if not self.pre_quantized and self.requires_calibration:
+ raise ValueError(
+ f"The quantization method {quantization_config.quant_method} requires the model to be pre-quantized."
+ f" You explicitly passed `pre_quantized=False`, meaning your model weights are not quantized. Make sure to"
+ f" pass `pre_quantized=True` while knowing what you are doing."
+ )
+
+ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ """
+ Some quantization methods require explicitly setting the dtype of the model to a target dtype. Override this
+ method if you want to make sure that behavior is preserved.
+
+ Args:
+ torch_dtype (`torch.dtype`):
+ The input dtype that is passed in `from_pretrained`
+ """
+ return torch_dtype
+
+ def update_device_map(self, device_map: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+ """
+ Override this method if you want to override the existing device map with a new one. E.g. for bitsandbytes,
+ since `accelerate` is a hard requirement, if no device_map is passed, the device_map is set to `"auto"`.
+
+ Args:
+ device_map (`Union[dict, str]`, *optional*):
+ The device_map that is passed through the `from_pretrained` method.
+ """
+ return device_map
+
+ def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ """
+ Override this method if you want to adjust the `target_dtype` variable used in `from_pretrained` to compute the
+ device_map in case the device_map is a `str`. E.g. for bitsandbytes we force-set `target_dtype` to `torch.int8`
+ and for 4-bit we pass the custom enum `accelerate.CustomDtype.INT4`.
+
+ Args:
+ torch_dtype (`torch.dtype`, *optional*):
+ The torch_dtype that is used to compute the device_map.
+ """
+ return torch_dtype
+
+ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
+ """
+ Override this method if you want to adjust the `missing_keys`.
+
+ Args:
+ missing_keys (`List[str]`, *optional*):
+ The list of missing keys in the checkpoint compared to the state dict of the model
+ """
+ return missing_keys
+
+ def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
+ """
+ Returns dtypes for modules that are not quantized - used for the computation of the device_map in case one
+ passes a str as a device_map. The method will use the `modules_to_not_convert` that is modified in
+ `_process_model_before_weight_loading`. `diffusers` models don't have any `modules_to_not_convert` attributes
+ yet, but this may change in the future.
+
+ Args:
+ model (`~diffusers.models.modeling_utils.ModelMixin`):
+ The model to quantize
+ torch_dtype (`torch.dtype`):
+ The dtype passed in `from_pretrained` method.
+ """
+
+ return {
+ name: torch_dtype
+ for name, _ in model.named_parameters()
+ if any(m in name for m in self.modules_to_not_convert)
+ }
+
+ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+ """Adjust the max_memory argument for infer_auto_device_map() if extra memory is needed for quantization."""
+ return max_memory
+
+ def check_if_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ state_dict: Dict[str, Any],
+ **kwargs,
+ ) -> bool:
+ """
+ Checks if a loaded state_dict component is part of a quantized param plus some validation; only defined for
+ quantization methods that require creating new parameters for quantization.
+ """
+ return False
+
+ def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter":
+ """
+ Takes the needed components from the state_dict and creates a quantized param.
+ """
+ return
+
+ def check_quantized_param_shape(self, *args, **kwargs):
+ """
+ Checks if the quantized param has the expected shape.
+ """
+ return True
+
+ def validate_environment(self, *args, **kwargs):
+ """
+ This method is used to check for potential conflicts with arguments that are passed in `from_pretrained`.
+ You need to define it for all future quantizers that are integrated with diffusers. If no explicit checks are
+ needed, simply return nothing.
+ """
+ return
+
+ def preprocess_model(self, model: "ModelMixin", **kwargs):
+ """
+ Setting model attributes and/or converting the model before weight loading. At this point the model should be
+ initialized on the meta device so you can freely manipulate the skeleton of the model in order to replace
+ modules in-place. Make sure to override the abstract method `_process_model_before_weight_loading`.
+
+ Args:
+ model (`~diffusers.models.modeling_utils.ModelMixin`):
+ The model to quantize
+ kwargs (`dict`, *optional*):
+ The keyword arguments that are passed along `_process_model_before_weight_loading`.
+ """
+ model.is_quantized = True
+ model.quantization_method = self.quantization_config.quant_method
+ return self._process_model_before_weight_loading(model, **kwargs)
+
+ def postprocess_model(self, model: "ModelMixin", **kwargs):
+ """
+ Post-process the model after weight loading. Make sure to override the abstract method
+ `_process_model_after_weight_loading`.
+
+ Args:
+ model (`~diffusers.models.modeling_utils.ModelMixin`):
+ The model to quantize
+ kwargs (`dict`, *optional*):
+ The keyword arguments that are passed along `_process_model_after_weight_loading`.
+ """
+ return self._process_model_after_weight_loading(model, **kwargs)
+
+ def dequantize(self, model):
+ """
+ Potentially dequantize the model to retrieve the original model, with some loss in accuracy / performance. Note
+ that not all quantization schemes support this.
+ """
+ model = self._dequantize(model)
+
+ # Delete quantizer and quantization config
+ del model.hf_quantizer
+
+ return model
+
+ def _dequantize(self, model):
+ raise NotImplementedError(
+ f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
+ )
+
+ @abstractmethod
+ def _process_model_before_weight_loading(self, model, **kwargs):
+ ...
+
+ @abstractmethod
+ def _process_model_after_weight_loading(self, model, **kwargs):
+ ...
+
+ @property
+ @abstractmethod
+ def is_serializable(self):
+ ...
+
+ @property
+ @abstractmethod
+ def is_trainable(self):
+ ...
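To make the abstract contract concrete, here is a minimal, purely hypothetical no-op subclass; a concrete quantizer only has to provide the two `_process_model_*` hooks and the two abstract properties:

from diffusers.quantizers.base import DiffusersQuantizer


class NoOpDiffusersQuantizer(DiffusersQuantizer):
    """Hypothetical example subclass; it leaves the model untouched."""

    def _process_model_before_weight_loading(self, model, **kwargs):
        # Normally: swap modules in place while the model still sits on the meta device.
        return model

    def _process_model_after_weight_loading(self, model, **kwargs):
        # Normally: set flags such as `is_loaded_in_4bit` on the model.
        return model

    @property
    def is_serializable(self):
        return True

    @property
    def is_trainable(self):
        return False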
diff --git a/src/diffusers/quantizers/bitsandbytes/__init__.py b/src/diffusers/quantizers/bitsandbytes/__init__.py
new file mode 100644
index 000000000000..9e745bc810fa
--- /dev/null
+++ b/src/diffusers/quantizers/bitsandbytes/__init__.py
@@ -0,0 +1,2 @@
+from .bnb_quantizer import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
+from .utils import dequantize_and_replace, dequantize_bnb_weight, replace_with_bnb_linear
diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
new file mode 100644
index 000000000000..d5ac1611a571
--- /dev/null
+++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
@@ -0,0 +1,558 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Adapted from
+https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/quantizer_bnb_4bit.py
+"""
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from ...utils import get_module_from_name
+from ..base import DiffusersQuantizer
+
+
+if TYPE_CHECKING:
+ from ...models.modeling_utils import ModelMixin
+
+from ...utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ is_bitsandbytes_available,
+ is_bitsandbytes_version,
+ is_torch_available,
+ logging,
+)
+
+
+if is_torch_available():
+ import torch
+
+logger = logging.get_logger(__name__)
+
+
+class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
+ """
+ 4-bit quantization from the bitsandbytes quantization method:
+ - before loading: converts transformer layers into `Linear4bit`
+ - during loading: loads the 16-bit weights and passes them to the layer object
+ - after loading: quantizes individual weights in `Linear4bit` into 4-bit at the first `.cuda()` call
+ - saving: from the state dict, as usual; saves the weights and the `quant_state` components
+ - loading: needs to locate the `quant_state` components and pass them to the `Params4bit` constructor
+ """
+
+ use_keep_in_fp32_modules = True
+ requires_calibration = False
+
+ def __init__(self, quantization_config, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+
+ if self.quantization_config.llm_int8_skip_modules is not None:
+ self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
+
+ def validate_environment(self, *args, **kwargs):
+ if not torch.cuda.is_available():
+ raise RuntimeError("No GPU found. A GPU is needed for quantization.")
+ if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
+ raise ImportError(
+ "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
+ )
+ if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"):
+ raise ImportError(
+ "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
+ )
+
+ if kwargs.get("from_flax", False):
+ raise ValueError(
+ "Converting into 4-bit weights from flax weights is currently not supported, please make"
+ " sure the weights are in PyTorch format."
+ )
+
+ device_map = kwargs.get("device_map", None)
+ if (
+ device_map is not None
+ and isinstance(device_map, dict)
+ and not self.quantization_config.llm_int8_enable_fp32_cpu_offload
+ ):
+ device_map_without_no_convert = {
+ key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert
+ }
+ if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
+ raise ValueError(
+ "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
+ "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
+ "in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom `device_map` to "
+ "`from_pretrained`. Check "
+ "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu "
+ "for more details. "
+ )
+
+ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+ if target_dtype != torch.int8:
+ from accelerate.utils import CustomDtype
+
+ logger.info(f"target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization")
+ return CustomDtype.INT4
+ else:
+ raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.")
+
+ def check_if_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ state_dict: Dict[str, Any],
+ **kwargs,
+ ) -> bool:
+ import bitsandbytes as bnb
+
+ module, tensor_name = get_module_from_name(model, param_name)
+ if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit):
+ # Add here check for loaded components' dtypes once serialization is implemented
+ return True
+ elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias":
+ # bias could be loaded by regular set_module_tensor_to_device() from accelerate,
+ # but it would wrongly use uninitialized weight there.
+ return True
+ else:
+ return False
+
+ def create_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ target_device: "torch.device",
+ state_dict: Dict[str, Any],
+ unexpected_keys: Optional[List[str]] = None,
+ ):
+ import bitsandbytes as bnb
+
+ module, tensor_name = get_module_from_name(model, param_name)
+
+ if tensor_name not in module._parameters:
+ raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
+
+ old_value = getattr(module, tensor_name)
+
+ if tensor_name == "bias":
+ if param_value is None:
+ new_value = old_value.to(target_device)
+ else:
+ new_value = param_value.to(target_device)
+
+ new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad)
+ module._parameters[tensor_name] = new_value
+ return
+
+ if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit):
+ raise ValueError("this function only loads `Linear4bit` components")
+ if (
+ old_value.device == torch.device("meta")
+ and target_device not in ["meta", torch.device("meta")]
+ and param_value is None
+ ):
+ raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put it on {target_device}.")
+
+ # construct `new_value` for the module._parameters[tensor_name]:
+ if self.pre_quantized:
+ # 4bit loading. Collecting components for restoring quantized weight
+ # This can be expanded to make a universal call for any quantized weight loading
+
+ if not self.is_serializable:
+ raise ValueError(
+ "Detected int4 weights but the version of bitsandbytes is not compatible with int4 serialization. "
+ "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
+ )
+
+ if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and (
+ param_name + ".quant_state.bitsandbytes__nf4" not in state_dict
+ ):
+ raise ValueError(
+ f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components."
+ )
+
+ quantized_stats = {}
+ for k, v in state_dict.items():
+ # `startswith` to account for edge cases where the `param_name`
+ # substring can be present in multiple places in the `state_dict`
+ if param_name + "." in k and k.startswith(param_name):
+ quantized_stats[k] = v
+ if unexpected_keys is not None and k in unexpected_keys:
+ unexpected_keys.remove(k)
+
+ new_value = bnb.nn.Params4bit.from_prequantized(
+ data=param_value,
+ quantized_stats=quantized_stats,
+ requires_grad=False,
+ device=target_device,
+ )
+ else:
+ new_value = param_value.to("cpu")
+ kwargs = old_value.__dict__
+ new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device)
+
+ module._parameters[tensor_name] = new_value
+
+ def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape):
+ n = current_param_shape.numel()
+ inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
+ if loaded_param_shape != inferred_shape:
+ raise ValueError(
+ f"Expected the flattened shape of the current param ({param_name}) to be {inferred_shape} but it is {loaded_param_shape}."
+ )
+ else:
+ return True
+
+ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+ # need more space for buffers that are created during quantization
+ max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+ return max_memory
+
+ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ if torch_dtype is None:
+ # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
+ logger.info(
+ "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to "
+ "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
+ "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
+ " torch_dtype=torch.float16 to remove this warning.",
+ torch_dtype,
+ )
+ torch_dtype = torch.float16
+ return torch_dtype
+
+ # (sayakpaul): I think it could be better to disable custom `device_map`s
+ # for the first phase of the integration in the interest of simplicity.
+ # Commenting this for discussions on the PR.
+ # def update_device_map(self, device_map):
+ # if device_map is None:
+ # device_map = {"": torch.cuda.current_device()}
+ # logger.info(
+ # "The device_map was not initialized. "
+ # "Setting device_map to {'':torch.cuda.current_device()}. "
+ # "If you want to use the model for inference, please set device_map ='auto' "
+ # )
+ # return device_map
+
+ def _process_model_before_weight_loading(
+ self,
+ model: "ModelMixin",
+ device_map,
+ keep_in_fp32_modules: List[str] = [],
+ **kwargs,
+ ):
+ from .utils import replace_with_bnb_linear
+
+ load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
+
+ # We may keep some modules such as the `proj_out` in their original dtype for numerical stability reasons
+ self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
+
+ if not isinstance(self.modules_to_not_convert, list):
+ self.modules_to_not_convert = [self.modules_to_not_convert]
+
+ self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+ # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+ if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+ keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+
+ if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload:
+ raise ValueError(
+ "If you want to offload some keys to `cpu` or `disk`, you need to set "
+ "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
+ " converted to 8-bit but kept in 32-bit."
+ )
+ self.modules_to_not_convert.extend(keys_on_cpu)
+
+ # Purge `None`.
+ # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
+ # in case of diffusion transformer models. For language models and others alike, `lm_head`
+ # and tied modules are usually kept in FP32.
+ self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
+
+ model = replace_with_bnb_linear(
+ model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
+ )
+ model.config.quantization_config = self.quantization_config
+
+ def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
+ model.is_loaded_in_4bit = True
+ model.is_4bit_serializable = self.is_serializable
+ return model
+
+ @property
+ def is_serializable(self):
+ # Because we're mandating `bitsandbytes` 0.43.3.
+ return True
+
+ @property
+ def is_trainable(self) -> bool:
+ # Because we're mandating `bitsandbytes` 0.43.3.
+ return True
+
+ def _dequantize(self, model):
+ from .utils import dequantize_and_replace
+
+ is_model_on_cpu = model.device.type == "cpu"
+ if is_model_on_cpu:
+ logger.info(
+ "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
+ )
+ model.to(torch.cuda.current_device())
+
+ model = dequantize_and_replace(
+ model, self.modules_to_not_convert, quantization_config=self.quantization_config
+ )
+ if is_model_on_cpu:
+ model.to("cpu")
+ return model
+
+
+class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
+ """
+ 8-bit quantization from the bitsandbytes quantization method:
+ - before loading: converts transformer layers into `Linear8bitLt`
+ - during loading: loads the 16-bit weights and passes them to the layer object
+ - after loading: quantizes individual weights in `Linear8bitLt` into 8-bit at the first `.cuda()` call
+ - saving: from the state dict, as usual; saves the weights and the `SCB` component
+ - loading: needs to locate the `SCB` component and pass it to the `Linear8bitLt` object
+ """
+
+ use_keep_in_fp32_modules = True
+ requires_calibration = False
+
+ def __init__(self, quantization_config, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+
+ if self.quantization_config.llm_int8_skip_modules is not None:
+ self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
+
+ def validate_environment(self, *args, **kwargs):
+ if not torch.cuda.is_available():
+ raise RuntimeError("No GPU found. A GPU is needed for quantization.")
+ if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
+ raise ImportError(
+ "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
+ )
+ if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"):
+ raise ImportError(
+ "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
+ )
+
+ if kwargs.get("from_flax", False):
+ raise ValueError(
+ "Converting into 8-bit weights from flax weights is currently not supported, please make"
+ " sure the weights are in PyTorch format."
+ )
+
+ device_map = kwargs.get("device_map", None)
+ if (
+ device_map is not None
+ and isinstance(device_map, dict)
+ and not self.quantization_config.llm_int8_enable_fp32_cpu_offload
+ ):
+ device_map_without_no_convert = {
+ key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert
+ }
+ if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
+ raise ValueError(
+ "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
+ "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
+ "in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom `device_map` to "
+ "`from_pretrained`. Check "
+ "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu "
+ "for more details. "
+ )
+
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory
+ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+ # need more space for buffers that are created during quantization
+ max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+ return max_memory
+
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_torch_dtype
+ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ if torch_dtype is None:
+ # We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
+ logger.info(
+ "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to "
+ "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. "
+ "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
+ " torch_dtype=torch.float16 to remove this warning.",
+ torch_dtype,
+ )
+ torch_dtype = torch.float16
+ return torch_dtype
+
+ # # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
+ # def update_device_map(self, device_map):
+ # if device_map is None:
+ # device_map = {"": torch.cuda.current_device()}
+ # logger.info(
+ # "The device_map was not initialized. "
+ # "Setting device_map to {'':torch.cuda.current_device()}. "
+ # "If you want to use the model for inference, please set device_map ='auto' "
+ # )
+ # return device_map
+
+ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+ if target_dtype != torch.int8:
+ logger.info(f"target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization")
+ return torch.int8
+
+ def check_if_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ state_dict: Dict[str, Any],
+ **kwargs,
+ ):
+ import bitsandbytes as bnb
+
+ module, tensor_name = get_module_from_name(model, param_name)
+ if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Int8Params):
+ if self.pre_quantized:
+ if param_name.replace("weight", "SCB") not in state_dict.keys():
+ raise ValueError("Missing quantization component `SCB`")
+ if param_value.dtype != torch.int8:
+ raise ValueError(
+ f"Incompatible dtype `{param_value.dtype}` when loading 8-bit prequantized weight. Expected `torch.int8`."
+ )
+ return True
+ return False
+
+ def create_quantized_param(
+ self,
+ model: "ModelMixin",
+ param_value: "torch.Tensor",
+ param_name: str,
+ target_device: "torch.device",
+ state_dict: Dict[str, Any],
+ unexpected_keys: Optional[List[str]] = None,
+ ):
+ import bitsandbytes as bnb
+
+ fp16_statistics_key = param_name.replace("weight", "SCB")
+ fp16_weights_format_key = param_name.replace("weight", "weight_format")
+
+ fp16_statistics = state_dict.get(fp16_statistics_key, None)
+ fp16_weights_format = state_dict.get(fp16_weights_format_key, None)
+
+ module, tensor_name = get_module_from_name(model, param_name)
+ if tensor_name not in module._parameters:
+ raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
+
+ old_value = getattr(module, tensor_name)
+
+ if not isinstance(module._parameters[tensor_name], bnb.nn.Int8Params):
+ raise ValueError(f"Parameter `{tensor_name}` should only be a `bnb.nn.Int8Params` instance.")
+ if (
+ old_value.device == torch.device("meta")
+ and target_device not in ["meta", torch.device("meta")]
+ and param_value is None
+ ):
+ raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put it on {target_device}.")
+
+ new_value = param_value.to("cpu")
+ if self.pre_quantized and not self.is_serializable:
+ raise ValueError(
+ "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
+ "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
+ )
+
+ kwargs = old_value.__dict__
+ new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(target_device)
+
+ module._parameters[tensor_name] = new_value
+ if fp16_statistics is not None:
+ setattr(module.weight, "SCB", fp16_statistics.to(target_device))
+ if unexpected_keys is not None:
+ unexpected_keys.remove(fp16_statistics_key)
+
+ # We just need to pop the `weight_format` keys from the state dict to remove unneeded
+ # messages. The correct format is retrieved during the first forward pass.
+ if fp16_weights_format is not None and unexpected_keys is not None:
+ unexpected_keys.remove(fp16_weights_format_key)
+
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_after_weight_loading with 4bit->8bit
+ def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
+ model.is_loaded_in_8bit = True
+ model.is_8bit_serializable = self.is_serializable
+ return model
+
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading
+ def _process_model_before_weight_loading(
+ self,
+ model: "ModelMixin",
+ device_map,
+ keep_in_fp32_modules: List[str] = [],
+ **kwargs,
+ ):
+ from .utils import replace_with_bnb_linear
+
+ load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload
+
+ # We may keep some modules such as the `proj_out` in their original dtype for numerical stability reasons
+ self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
+
+ if not isinstance(self.modules_to_not_convert, list):
+ self.modules_to_not_convert = [self.modules_to_not_convert]
+
+ self.modules_to_not_convert.extend(keep_in_fp32_modules)
+
+ # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk`
+ if isinstance(device_map, dict) and len(device_map.keys()) > 1:
+ keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
+
+ if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload:
+ raise ValueError(
+ "If you want to offload some keys to `cpu` or `disk`, you need to set "
+ "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be "
+ " converted to 8-bit but kept in 32-bit."
+ )
+ self.modules_to_not_convert.extend(keys_on_cpu)
+
+ # Purge `None`.
+ # Unlike `transformers`, we don't know if we should always keep certain modules in FP32
+ # in case of diffusion transformer models. For language models and others alike, `lm_head`
+ # and tied modules are usually kept in FP32.
+ self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
+
+ model = replace_with_bnb_linear(
+ model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
+ )
+ model.config.quantization_config = self.quantization_config
+
+ @property
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable
+ def is_serializable(self):
+ # Because we're mandating `bitsandbytes` 0.43.3.
+ return True
+
+ @property
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable
+ def is_trainable(self) -> bool:
+ # Because we're mandating `bitsandbytes` 0.43.3.
+ return True
+
+ def _dequantize(self, model):
+ from .utils import dequantize_and_replace
+
+ model = dequantize_and_replace(
+ model, self.modules_to_not_convert, quantization_config=self.quantization_config
+ )
+ return model
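Two of the smaller hooks above can be exercised without a GPU, which also documents the packing convention the 4-bit path enforces: `update_torch_dtype` falls back to `torch.float16`, and `check_quantized_param_shape` expects 4-bit weights packed two values per byte into a `(ceil(n / 2), 1)` tensor. A small sketch (the parameter name below is made up):

import torch
from diffusers.quantizers.bitsandbytes import BnB4BitDiffusersQuantizer
from diffusers.quantizers.quantization_config import BitsAndBytesConfig

quantizer = BnB4BitDiffusersQuantizer(BitsAndBytesConfig(load_in_4bit=True))

# With no user-supplied dtype, bitsandbytes loading defaults to float16.
assert quantizer.update_torch_dtype(None) == torch.float16

# A (64, 32) weight has 2048 values -> packed 4-bit storage of shape (1024, 1).
current_shape = torch.Size([64, 32])
packed_shape = torch.Size([1024, 1])
assert quantizer.check_quantized_param_shape("transformer.proj.weight", current_shape, packed_shape)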
diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py
new file mode 100644
index 000000000000..03755db3d1ec
--- /dev/null
+++ b/src/diffusers/quantizers/bitsandbytes/utils.py
@@ -0,0 +1,306 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Adapted from
+https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py
+"""
+
+import inspect
+from inspect import signature
+ from typing import Tuple
+
+from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging
+from ..quantization_config import QuantizationMethod
+
+
+if is_torch_available():
+ import torch
+ import torch.nn as nn
+
+if is_bitsandbytes_available():
+ import bitsandbytes as bnb
+
+if is_accelerate_available():
+ import accelerate
+ from accelerate import init_empty_weights
+ from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+
+logger = logging.get_logger(__name__)
+
+
+def _replace_with_bnb_linear(
+ model,
+ modules_to_not_convert=None,
+ current_key_name=None,
+ quantization_config=None,
+ has_been_replaced=False,
+):
+ """
+ Private method that wraps the recursion for module replacement.
+
+ Returns the converted model and a boolean that indicates if the conversion has been successful or not.
+ """
+ for name, module in model.named_children():
+ if current_key_name is None:
+ current_key_name = []
+ current_key_name.append(name)
+
+ if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
+ # Check if the current key is not in the `modules_to_not_convert`
+ current_key_name_str = ".".join(current_key_name)
+ if not any(
+ (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
+ ):
+ with init_empty_weights():
+ in_features = module.in_features
+ out_features = module.out_features
+
+ if quantization_config.quantization_method() == "llm_int8":
+ model._modules[name] = bnb.nn.Linear8bitLt(
+ in_features,
+ out_features,
+ module.bias is not None,
+ has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
+ threshold=quantization_config.llm_int8_threshold,
+ )
+ has_been_replaced = True
+ else:
+ if (
+ quantization_config.llm_int8_skip_modules is not None
+ and name in quantization_config.llm_int8_skip_modules
+ ):
+ pass
+ else:
+ extra_kwargs = (
+ {"quant_storage": quantization_config.bnb_4bit_quant_storage}
+ if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters)
+ else {}
+ )
+ model._modules[name] = bnb.nn.Linear4bit(
+ in_features,
+ out_features,
+ module.bias is not None,
+ quantization_config.bnb_4bit_compute_dtype,
+ compress_statistics=quantization_config.bnb_4bit_use_double_quant,
+ quant_type=quantization_config.bnb_4bit_quant_type,
+ **extra_kwargs,
+ )
+ has_been_replaced = True
+ # Store the module class in case we need to transpose the weight later
+ model._modules[name].source_cls = type(module)
+ # Force requires grad to False to avoid unexpected errors
+ model._modules[name].requires_grad_(False)
+ if len(list(module.children())) > 0:
+ _, has_been_replaced = _replace_with_bnb_linear(
+ module,
+ modules_to_not_convert,
+ current_key_name,
+ quantization_config,
+ has_been_replaced=has_been_replaced,
+ )
+ # Remove the last key for recursion
+ current_key_name.pop(-1)
+ return model, has_been_replaced
+
+
+def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
+ """
+ Helper function to replace the `nn.Linear` layers within `model` with either `bnb.nn.Linear8bitLt` or
+ `bnb.nn.Linear4bit` using the `bitsandbytes` library.
+
+ References:
+ * `bnb.nn.Linear8bitLt`: [LLM.int8(): 8-bit Matrix Multiplication for Transformers at
+ Scale](https://arxiv.org/abs/2208.07339)
+ * `bnb.nn.Linear4bit`: [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314)
+
+ Parameters:
+ model (`torch.nn.Module`):
+ Input model or `torch.nn.Module` as the function is run recursively.
+ modules_to_not_convert (`List[str]`, *optional*, defaults to `[]`):
+ Names of the modules to not convert to `Linear8bitLt`. In practice we keep the `modules_to_not_convert` in
+ full precision for numerical stability reasons.
+ current_key_name (`List[str]`, *optional*):
+ A list to track the current key of the recursion. This is used to check whether the current key (or part of
+ it) is not in the list of modules to not convert (for instance, modules that are offloaded to `cpu` or
+ `disk`).
+ quantization_config (`BitsAndBytesConfig`):
+ The quantization config object that holds the settings used to quantize the model, i.e. to compress the
+ network by reducing the precision of its weights and activations for more efficient storage and computation.
+ """
+ model, has_been_replaced = _replace_with_bnb_linear(
+ model, modules_to_not_convert, current_key_name, quantization_config
+ )
+
+ if not has_been_replaced:
+ logger.warning(
+ "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
+ " Please double check your model architecture, or submit an issue on github if you think this is"
+ " a bug."
+ )
+
+ return model
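A toy sketch of the replacement helper (illustrative; it needs `bitsandbytes` and `accelerate` installed, and since the swap happens under `init_empty_weights` the new layers start on the meta device until real weights are loaded):

import torch.nn as nn
from diffusers.quantizers.bitsandbytes import replace_with_bnb_linear
from diffusers.quantizers.quantization_config import BitsAndBytesConfig

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))

# With an 8-bit config, nn.Linear children become bnb.nn.Linear8bitLt, except the
# modules listed in `modules_to_not_convert` (here the child named "1").
model = replace_with_bnb_linear(
    model,
    modules_to_not_convert=["1"],
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
print(type(model[0]).__name__)  # Linear8bitLt
print(type(model[1]).__name__)  # Linear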
+
+
+# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
+def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
+ """
+ Helper function to dequantize 4bit or 8bit bnb weights.
+
+ If the weight is not a bnb quantized weight, it will be returned as is.
+ """
+ if not isinstance(weight, torch.nn.Parameter):
+ raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")
+
+ cls_name = weight.__class__.__name__
+ if cls_name not in ("Params4bit", "Int8Params"):
+ return weight
+
+ if cls_name == "Params4bit":
+ output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
+ logger.warning_once(
+ f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
+ )
+ return output_tensor
+
+ if state.SCB is None:
+ state.SCB = weight.SCB
+
+ im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
+ im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
+ im, Sim = bnb.functional.transform(im, "col32")
+ if state.CxB is None:
+ state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
+ out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
+ return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+
+
+def _create_accelerate_new_hook(old_hook):
+ r"""
+ Creates a new hook based on the old hook. Use it only if you know what you are doing! This method is a copy of:
+ https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with
+ some changes
+ """
+ old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
+ old_hook_attr = old_hook.__dict__
+ filtered_old_hook_attr = {}
+ old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
+ for k in old_hook_attr.keys():
+ if k in old_hook_init_signature.parameters:
+ filtered_old_hook_attr[k] = old_hook_attr[k]
+ new_hook = old_hook_cls(**filtered_old_hook_attr)
+ return new_hook
+
+
+def _dequantize_and_replace(
+ model,
+ modules_to_not_convert=None,
+ current_key_name=None,
+ quantization_config=None,
+ has_been_replaced=False,
+):
+ """
+ Converts a quantized model into its dequantized original version. The newly converted model will have some
+ performance drop compared to the original model before quantization - use it only for specific use cases such as
+ merging QLoRA adapters.
+
+ Returns the converted model and a boolean that indicates if the conversion has been successful or not.
+ """
+ quant_method = quantization_config.quantization_method()
+
+ target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit
+
+ for name, module in model.named_children():
+ if current_key_name is None:
+ current_key_name = []
+ current_key_name.append(name)
+
+ if isinstance(module, target_cls) and name not in modules_to_not_convert:
+ # Check if the current key is not in the `modules_to_not_convert`
+ current_key_name_str = ".".join(current_key_name)
+
+ if not any(
+ (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
+ ):
+ bias = getattr(module, "bias", None)
+
+ device = module.weight.device
+ with init_empty_weights():
+ new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)
+
+ if quant_method == "llm_int8":
+ state = module.state
+ else:
+ state = None
+
+ new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+
+ if bias is not None:
+ new_module.bias = bias
+
+ # Create a new hook and attach it in case we use accelerate
+ if hasattr(module, "_hf_hook"):
+ old_hook = module._hf_hook
+ new_hook = _create_accelerate_new_hook(old_hook)
+
+ remove_hook_from_module(module)
+ add_hook_to_module(new_module, new_hook)
+
+ new_module.to(device)
+ model._modules[name] = new_module
+ has_been_replaced = True
+ if len(list(module.children())) > 0:
+ _, has_been_replaced = _dequantize_and_replace(
+ module,
+ modules_to_not_convert,
+ current_key_name,
+ quantization_config,
+ has_been_replaced=has_been_replaced,
+ )
+ # Remove the last key for recursion
+ current_key_name.pop(-1)
+ return model, has_been_replaced
+
+
+def dequantize_and_replace(
+ model,
+ modules_to_not_convert=None,
+ quantization_config=None,
+):
+ model, has_been_replaced = _dequantize_and_replace(
+ model,
+ modules_to_not_convert=modules_to_not_convert,
+ quantization_config=quantization_config,
+ )
+
+ if not has_been_replaced:
+ logger.warning(
+ "For some reason the model has not been properly dequantized. You might see unexpected behavior."
+ )
+
+ return model
+
+
+def _check_bnb_status(module) -> "Tuple[bool, bool, bool]":
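+    r"""
+    Returns a 3-tuple `(is_loaded_in_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb)` for `module`.
+
+    A rough usage sketch (`pipe.transformer` is an illustrative attribute of a loaded pipeline):
+
+    ```py
+    is_bnb, is_4bit, is_8bit = _check_bnb_status(pipe.transformer)
+    if is_bnb:
+        ...  # e.g. skip dtype casts, which raise an error on bnb-quantized modules
+    ```
+    """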
+ is_loaded_in_4bit_bnb = (
+ hasattr(module, "is_loaded_in_4bit")
+ and module.is_loaded_in_4bit
+ and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+ )
+ is_loaded_in_8bit_bnb = (
+ hasattr(module, "is_loaded_in_8bit")
+ and module.is_loaded_in_8bit
+ and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+ )
+ return is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb
diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
new file mode 100644
index 000000000000..f521c5d717d6
--- /dev/null
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -0,0 +1,391 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Adapted from
+https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/utils/quantization_config.py
+"""
+
+import copy
+import importlib.metadata
+import json
+import os
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Dict, Union
+
+from packaging import version
+
+from ..utils import is_torch_available, logging
+
+
+if is_torch_available():
+ import torch
+
+logger = logging.get_logger(__name__)
+
+
+class QuantizationMethod(str, Enum):
+ BITS_AND_BYTES = "bitsandbytes"
+
+
+@dataclass
+class QuantizationConfigMixin:
+ """
+ Mixin class for quantization config
+ """
+
+ quant_method: QuantizationMethod
+ _exclude_attributes_at_init = []
+
+ @classmethod
+ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
+ """
+ Instantiates a [`QuantizationConfigMixin`] from a Python dictionary of parameters.
+
+ Args:
+ config_dict (`Dict[str, Any]`):
+ Dictionary that will be used to instantiate the configuration object.
+ return_unused_kwargs (`bool`,*optional*, defaults to `False`):
+ Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
+ `PreTrainedModel`.
+ kwargs (`Dict[str, Any]`):
+ Additional parameters from which to initialize the configuration object.
+
+ Returns:
+ [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
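+
+        A rough sketch of the expected round trip (uses the `BitsAndBytesConfig` subclass defined below and assumes
+        `bitsandbytes>=0.39.0` is installed):
+
+        ```py
+        config = BitsAndBytesConfig.from_dict({"load_in_4bit": True, "bnb_4bit_quant_type": "nf4"})
+        assert config.to_dict()["bnb_4bit_quant_type"] == "nf4"
+        ```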
+ """
+
+ config = cls(**config_dict)
+
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(config, key):
+ setattr(config, key, value)
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
+
+ if return_unused_kwargs:
+ return config, kwargs
+ else:
+ return config
+
+    def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+        """
+        Save this instance to a JSON file.
+
+        Args:
+            json_file_path (`str` or `os.PathLike`):
+                Path to the JSON file in which this configuration instance's parameters will be saved.
+        """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ config_dict = self.to_dict()
+ json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ writer.write(json_string)
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+        Serializes this instance to a Python dictionary.
+
+        Returns:
+            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ return copy.deepcopy(self.__dict__)
+
+ def __iter__(self):
+ """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
+ for attr, value in copy.deepcopy(self.__dict__).items():
+ yield attr, value
+
+ def __repr__(self):
+ return f"{self.__class__.__name__} {self.to_json_string()}"
+
+ def to_json_string(self, use_diff: bool = True) -> str:
+ """
+ Serializes this instance to a JSON string.
+
+ Args:
+ use_diff (`bool`, *optional*, defaults to `True`):
+ If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
+ is serialized to JSON string.
+
+ Returns:
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
+ """
+ if use_diff is True:
+ config_dict = self.to_diff_dict()
+ else:
+ config_dict = self.to_dict()
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ def update(self, **kwargs):
+ """
+ Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
+ returning all the unused kwargs.
+
+ Args:
+ kwargs (`Dict[str, Any]`):
+ Dictionary of attributes to tentatively update this class.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
+ """
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(self, key):
+ setattr(self, key, value)
+ to_remove.append(key)
+
+ # Remove all the attributes that were updated, without modifying the input dict
+ unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
+ return unused_kwargs
+
+
+@dataclass
+class BitsAndBytesConfig(QuantizationConfigMixin):
+ """
+    This is a wrapper class for all the possible attributes and features that you can tweak for a model that has been
+    loaded using `bitsandbytes`.
+
+    This replaces `load_in_8bit` or `load_in_4bit`; therefore, both options are mutually exclusive.
+
+ Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`,
+ then more arguments will be added to this class.
+
+ Args:
+ load_in_8bit (`bool`, *optional*, defaults to `False`):
+ This flag is used to enable 8-bit quantization with LLM.int8().
+ load_in_4bit (`bool`, *optional*, defaults to `False`):
+ This flag is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from
+ `bitsandbytes`.
+ llm_int8_threshold (`float`, *optional*, defaults to 6.0):
+ This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix
+ Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339 Any hidden states value
+ that is above this threshold will be considered an outlier and the operation on those values will be done
+ in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but
+ there are some exceptional systematic outliers that are very differently distributed for large models.
+ These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of
+ magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6,
+ but a lower threshold might be needed for more unstable models (small models, fine-tuning).
+ llm_int8_skip_modules (`List[str]`, *optional*):
+            An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as
+            Jukebox that have several heads in different places and not necessarily at the last position. For example,
+            for `CausalLM` models, the last `lm_head` is typically kept in its original `dtype`.
+ llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`):
+ This flag is used for advanced use cases and users that are aware of this feature. If you want to split
+ your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use
+ this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8
+ operations will not be run on CPU.
+ llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`):
+ This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning as the weights do not
+ have to be converted back and forth for the backward pass.
+ bnb_4bit_compute_dtype (`torch.dtype` or str, *optional*, defaults to `torch.float32`):
+ This sets the computational type which might be different than the input type. For example, inputs might be
+ fp32, but computation can be set to bf16 for speedups.
+ bnb_4bit_quant_type (`str`, *optional*, defaults to `"fp4"`):
+ This sets the quantization data type in the bnb.nn.Linear4Bit layers. Options are FP4 and NF4 data types
+ which are specified by `fp4` or `nf4`.
+ bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`):
+ This flag is used for nested quantization where the quantization constants from the first quantization are
+ quantized again.
+ bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`):
+            This sets the storage type to pack the quantized 4-bit params.
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional parameters from which to initialize the configuration object.
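+
+    A rough usage sketch that mirrors the accompanying tests (assumes a CUDA machine with `bitsandbytes` installed):
+
+    ```py
+    import torch
+    from diffusers import BitsAndBytesConfig, SD3Transformer2DModel
+
+    nf4_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_compute_dtype=torch.float16,
+    )
+    model_4bit = SD3Transformer2DModel.from_pretrained(
+        "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer", quantization_config=nf4_config
+    )
+    ```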
+ """
+
+ _exclude_attributes_at_init = ["_load_in_4bit", "_load_in_8bit", "quant_method"]
+
+ def __init__(
+ self,
+ load_in_8bit=False,
+ load_in_4bit=False,
+ llm_int8_threshold=6.0,
+ llm_int8_skip_modules=None,
+ llm_int8_enable_fp32_cpu_offload=False,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=None,
+ bnb_4bit_quant_type="fp4",
+ bnb_4bit_use_double_quant=False,
+ bnb_4bit_quant_storage=None,
+ **kwargs,
+ ):
+ self.quant_method = QuantizationMethod.BITS_AND_BYTES
+
+ if load_in_4bit and load_in_8bit:
+ raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
+
+ self._load_in_8bit = load_in_8bit
+ self._load_in_4bit = load_in_4bit
+ self.llm_int8_threshold = llm_int8_threshold
+ self.llm_int8_skip_modules = llm_int8_skip_modules
+ self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
+ self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
+ self.bnb_4bit_quant_type = bnb_4bit_quant_type
+ self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
+
+ if bnb_4bit_compute_dtype is None:
+ self.bnb_4bit_compute_dtype = torch.float32
+ elif isinstance(bnb_4bit_compute_dtype, str):
+ self.bnb_4bit_compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
+ elif isinstance(bnb_4bit_compute_dtype, torch.dtype):
+ self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
+ else:
+ raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
+
+ if bnb_4bit_quant_storage is None:
+ self.bnb_4bit_quant_storage = torch.uint8
+ elif isinstance(bnb_4bit_quant_storage, str):
+ if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]:
+ raise ValueError(
+ "`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') "
+ )
+ self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage)
+ elif isinstance(bnb_4bit_quant_storage, torch.dtype):
+ self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
+ else:
+ raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype")
+
+ if kwargs and not all(k in self._exclude_attributes_at_init for k in kwargs):
+ logger.warning(f"Unused kwargs: {list(kwargs.keys())}. These kwargs are not used in {self.__class__}.")
+
+ self.post_init()
+
+ @property
+ def load_in_4bit(self):
+ return self._load_in_4bit
+
+ @load_in_4bit.setter
+ def load_in_4bit(self, value: bool):
+ if not isinstance(value, bool):
+ raise TypeError("load_in_4bit must be a boolean")
+
+ if self.load_in_8bit and value:
+ raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
+ self._load_in_4bit = value
+
+ @property
+ def load_in_8bit(self):
+ return self._load_in_8bit
+
+ @load_in_8bit.setter
+ def load_in_8bit(self, value: bool):
+ if not isinstance(value, bool):
+ raise TypeError("load_in_8bit must be a boolean")
+
+ if self.load_in_4bit and value:
+ raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
+ self._load_in_8bit = value
+
+ def post_init(self):
+ r"""
+        Safety checker to verify that the arguments are correct.
+ """
+ if not isinstance(self.load_in_4bit, bool):
+ raise TypeError("load_in_4bit must be a boolean")
+
+ if not isinstance(self.load_in_8bit, bool):
+ raise TypeError("load_in_8bit must be a boolean")
+
+ if not isinstance(self.llm_int8_threshold, float):
+ raise TypeError("llm_int8_threshold must be a float")
+
+ if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list):
+ raise TypeError("llm_int8_skip_modules must be a list of strings")
+ if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool):
+ raise TypeError("llm_int8_enable_fp32_cpu_offload must be a boolean")
+
+ if not isinstance(self.llm_int8_has_fp16_weight, bool):
+ raise TypeError("llm_int8_has_fp16_weight must be a boolean")
+
+ if self.bnb_4bit_compute_dtype is not None and not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
+ raise TypeError("bnb_4bit_compute_dtype must be torch.dtype")
+
+ if not isinstance(self.bnb_4bit_quant_type, str):
+ raise TypeError("bnb_4bit_quant_type must be a string")
+
+ if not isinstance(self.bnb_4bit_use_double_quant, bool):
+ raise TypeError("bnb_4bit_use_double_quant must be a boolean")
+
+ if self.load_in_4bit and not version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse(
+ "0.39.0"
+ ):
+ raise ValueError(
+ "4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version"
+ )
+
+ def is_quantizable(self):
+ r"""
+ Returns `True` if the model is quantizable, `False` otherwise.
+ """
+ return self.load_in_8bit or self.load_in_4bit
+
+ def quantization_method(self):
+ r"""
+ This method returns the quantization method used for the model. If the model is not quantizable, it returns
+ `None`.
+ """
+ if self.load_in_8bit:
+ return "llm_int8"
+ elif self.load_in_4bit and self.bnb_4bit_quant_type == "fp4":
+ return "fp4"
+ elif self.load_in_4bit and self.bnb_4bit_quant_type == "nf4":
+ return "nf4"
+ else:
+ return None
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+        Serializes this instance to a Python dictionary.
+
+        Returns:
+            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1]
+ output["bnb_4bit_quant_storage"] = str(output["bnb_4bit_quant_storage"]).split(".")[1]
+ output["load_in_4bit"] = self.load_in_4bit
+ output["load_in_8bit"] = self.load_in_8bit
+
+ return output
+
+ def __repr__(self):
+ config_dict = self.to_dict()
+ return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
+
+ def to_diff_dict(self) -> Dict[str, Any]:
+ """
+        Removes all attributes from the config which correspond to the default config attributes, for better
+        readability, and serializes it to a Python dictionary.
+
+        Returns:
+            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ config_dict = self.to_dict()
+
+ # get the default config dict
+ default_config_dict = BitsAndBytesConfig().to_dict()
+
+ serializable_config_dict = {}
+
+ # only serialize values that differ from the default config
+ for key, value in config_dict.items():
+ if value != default_config_dict[key]:
+ serializable_config_dict[key] = value
+
+ return serializable_config_dict
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 3329919cfb02..6841a34a6489 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -330,6 +330,7 @@ def set_timesteps(
# Clipping the minimum of all lambda(t) for numerical stability.
# This is critical for cosine (squaredcos_cap_v2) noise schedule.
clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped)
+ clipped_idx = clipped_idx.item()
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1)
.round()[::-1][:-1]
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index c7ea2bcc5b7f..c8f64adf3e8a 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -62,6 +62,7 @@
is_accelerate_available,
is_accelerate_version,
is_bitsandbytes_available,
+ is_bitsandbytes_version,
is_bs4_available,
is_flax_available,
is_ftfy_available,
@@ -94,7 +95,7 @@
is_xformers_available,
requires_backends,
)
-from .loading_utils import load_image, load_video
+from .loading_utils import get_module_from_name, load_image, load_video
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index eaab67c93b18..10d0399a6761 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -1020,6 +1020,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class DiffusersQuantizer(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AmusedScheduler(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index d8109eee6d35..9046a4f73533 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -272,6 +272,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class CogVideoXFunControlPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class CogVideoXImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index 3ff859f17fe3..448e92509732 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -457,7 +457,7 @@ def _get_checkpoint_shard_files(
ignore_patterns = ["*.json", "*.md"]
if not local_files_only:
# `model_info` call must guarded with the above condition.
- model_files_info = model_info(pretrained_model_name_or_path, revision=revision)
+ model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
for shard_file in original_shard_filenames:
shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
if not shard_file_present:
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index daecec4aa258..f1323bf00ea4 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -745,6 +745,20 @@ def is_peft_version(operation: str, version: str):
return compare_versions(parse(_peft_version), operation, version)
+def is_bitsandbytes_version(operation: str, version: str):
+ """
+    Compares the current bitsandbytes version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
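+
+    A rough usage sketch (returns `False` when bitsandbytes is not installed at all):
+
+    ```py
+    if is_bitsandbytes_version(">=", "0.43.2"):
+        ...
+    ```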
+ """
+ if not _bitsandbytes_version:
+ return False
+ return compare_versions(parse(_bitsandbytes_version), operation, version)
+
+
def is_k_diffusion_version(operation: str, version: str):
"""
Compares the current k-diffusion version to a given reference with an operation.
diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py
index b36664cb81ff..bac24fa23e63 100644
--- a/src/diffusers/utils/loading_utils.py
+++ b/src/diffusers/utils/loading_utils.py
@@ -1,6 +1,6 @@
import os
import tempfile
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
from urllib.parse import unquote, urlparse
import PIL.Image
@@ -135,3 +135,16 @@ def load_video(
pil_images = convert_method(pil_images)
return pil_images
+
+
+# Taken from `transformers`.
+def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]:
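+    """
+    Splits a dotted tensor path into its parent module and the leaf tensor name.
+
+    A rough sketch, assuming `model` is a `torch.nn.Module` that actually contains the (illustrative) dotted path:
+
+    ```py
+    parent, leaf = get_module_from_name(model, "transformer_blocks.0.attn.to_q.weight")
+    # `parent` is the `to_q` submodule and `leaf` is the string "weight"
+    ```
+    """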
+ if "." in tensor_name:
+ splits = tensor_name.split(".")
+ for split in splits[:-1]:
+ new_module = getattr(module, split)
+ if new_module is None:
+ raise ValueError(f"{module} has no attribute {split}.")
+ module = new_module
+ tensor_name = splits[-1]
+ return module, tensor_name
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index ca55192ff7ae..dcc78a547a13 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -134,14 +134,14 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
"""
from peft.tuners.tuners_utils import BaseTunerLayer
- if weight == 1.0:
+ if weight is None or weight == 1.0:
return
for module in model.modules():
if isinstance(module, BaseTunerLayer):
- if weight is not None and weight != 0:
+ if weight != 0:
module.unscale_layer(weight)
- elif weight is not None and weight == 0:
+ else:
for adapter_name in module.active_adapters:
# if weight == 0 unscale should re-set the scale to the original value.
module.set_scale(adapter_name, 1.0)
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index a2f283d0c4f5..6361cca663b9 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -1,5 +1,6 @@
import functools
import importlib
+import importlib.metadata
import inspect
import io
import logging
@@ -27,6 +28,8 @@
from .import_utils import (
BACKENDS_MAPPING,
+ is_accelerate_available,
+ is_bitsandbytes_available,
is_compel_available,
is_flax_available,
is_note_seq_available,
@@ -371,6 +374,20 @@ def require_timm(test_case):
return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
+def require_bitsandbytes(test_case):
+ """
+ Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed.
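+
+    A rough usage sketch (`MyBnBTests` is an illustrative test case, not part of this module):
+
+    ```py
+    @require_bitsandbytes
+    @require_accelerate
+    class MyBnBTests(unittest.TestCase):
+        ...
+    ```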
+ """
+ return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
+
+
+def require_accelerate(test_case):
+ """
+ Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
+ """
+ return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
+
+
def require_peft_version_greater(peft_version):
"""
Decorator marking a test that requires PEFT backend with a specific version, this would require some specific
@@ -408,7 +425,7 @@ def decorator(test_case):
def require_accelerate_version_greater(accelerate_version):
def decorator(test_case):
- correct_accelerate_version = is_peft_available() and version.parse(
+ correct_accelerate_version = is_accelerate_available() and version.parse(
version.parse(importlib.metadata.version("accelerate")).base_version
) > version.parse(accelerate_version)
return unittest.skipUnless(
@@ -418,6 +435,18 @@ def decorator(test_case):
return decorator
+def require_bitsandbytes_version_greater(bnb_version):
+ def decorator(test_case):
+ correct_bnb_version = is_bitsandbytes_available() and version.parse(
+ version.parse(importlib.metadata.version("bitsandbytes")).base_version
+ ) > version.parse(bnb_version)
+ return unittest.skipUnless(
+ correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}."
+ )(test_case)
+
+ return decorator
+
+
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend
diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py
index 2b9084327289..2be4744c5ac4 100644
--- a/tests/models/transformers/test_models_transformer_sd3.py
+++ b/tests/models/transformers/test_models_transformer_sd3.py
@@ -73,6 +73,65 @@ def prepare_init_args_and_inputs_for_common(self):
"joint_attention_dim": 32,
"pooled_projection_dim": 64,
"out_channels": 4,
+ "pos_embed_max_size": 96,
+ "dual_attention_layers": (),
+ "qk_norm": None,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
+ def test_set_attn_processor_for_determinism(self):
+ pass
+
+
+class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = SD3Transformer2DModel
+ main_input_name = "hidden_states"
+
+ @property
+ def dummy_input(self):
+ batch_size = 2
+ num_channels = 4
+ height = width = embedding_dim = 32
+ pooled_embedding_dim = embedding_dim * 2
+ sequence_length = 154
+
+ hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "pooled_projections": pooled_prompt_embeds,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (4, 32, 32)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "sample_size": 32,
+ "patch_size": 1,
+ "in_channels": 4,
+ "num_layers": 2,
+ "attention_head_dim": 8,
+ "num_attention_heads": 4,
+ "caption_projection_dim": 32,
+ "joint_attention_dim": 32,
+ "pooled_projection_dim": 64,
+ "out_channels": 4,
+ "pos_embed_max_size": 96,
+ "dual_attention_layers": (0,),
+ "qk_norm": "rms_norm",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
new file mode 100644
index 000000000000..2a51fc65798c
--- /dev/null
+++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py
@@ -0,0 +1,324 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLCogVideoX, CogVideoXFunControlPipeline, CogVideoXTransformer3DModel, DDIMScheduler
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+ check_qkv_fusion_matches_attn_procs_length,
+ check_qkv_fusion_processors_exist,
+ to_np,
+)
+
+
+enable_full_determinism()
+
+
+class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = CogVideoXFunControlPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"control_video"})
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = CogVideoXTransformer3DModel(
+ # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings
+ # But, since we are using tiny-random-t5 here, we need the internal dim of CogVideoXTransformer3DModel
+ # to be 32. The internal dim is product of num_attention_heads and attention_head_dim
+ num_attention_heads=4,
+ attention_head_dim=8,
+ in_channels=8,
+ out_channels=4,
+ time_embed_dim=2,
+ text_embed_dim=32, # Must match with tiny-random-t5
+ num_layers=1,
+ sample_width=2, # latent width: 2 -> final width: 16
+ sample_height=2, # latent height: 2 -> final height: 16
+ sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9
+ patch_size=2,
+ temporal_compression_ratio=4,
+ max_text_seq_length=16,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLCogVideoX(
+ in_channels=3,
+ out_channels=3,
+ down_block_types=(
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ "CogVideoXDownBlock3D",
+ ),
+ up_block_types=(
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ "CogVideoXUpBlock3D",
+ ),
+ block_out_channels=(8, 8, 8, 8),
+ latent_channels=4,
+ layers_per_block=1,
+ norm_num_groups=2,
+ temporal_compression_ratio=4,
+ )
+
+ torch.manual_seed(0)
+ scheduler = DDIMScheduler()
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed: int = 0, num_frames: int = 8):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ # Cannot reduce because convolution kernel becomes bigger than sample
+ height = 16
+ width = 16
+
+ control_video = [Image.new("RGB", (width, height))] * num_frames
+
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "",
+ "control_video": control_video,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": height,
+ "width": width,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (8, 3, 16, 16))
+ expected_video = torch.randn(8, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+
+ # iterate over callback args
+ for tensor_name, tensor_value in callback_kwargs.items():
+ # check that we're only passing in allowed tensor inputs
+ assert tensor_name in pipe._callback_tensor_inputs
+
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # Test passing in a subset
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ output = pipe(**inputs)[0]
+
+        # Test passing in everything
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+
+ def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
+ is_last = i == (pipe.num_timesteps - 1)
+ if is_last:
+ callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
+ return callback_kwargs
+
+ inputs["callback_on_step_end"] = callback_inputs_change_tensor
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ output = pipe(**inputs)[0]
+ assert output.abs().sum() < 1e10
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.5):
+ # NOTE(aryan): This requires a higher expected_max_diff than other CogVideoX pipelines
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_overlap_factor_height=1 / 12,
+ tile_overlap_factor_width=1 / 12,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ frames = pipe(**inputs).frames # [B, F, C, H, W]
+ original_image_slice = frames[0, -2:, -1, -3:, -3:]
+
+ pipe.fuse_qkv_projections()
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_matches_attn_procs_length(
+ pipe.transformer, pipe.transformer.original_attn_processors
+ ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+ inputs = self.get_dummy_inputs(device)
+ frames = pipe(**inputs).frames
+ image_slice_fused = frames[0, -2:, -1, -3:, -3:]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ frames = pipe(**inputs).frames
+ image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
+
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
diff --git a/tests/pipelines/cogvideo/test_cogvideox_image2video.py b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
index ec9a5fdd153e..f7e1fe7fd6c7 100644
--- a/tests/pipelines/cogvideo/test_cogvideox_image2video.py
+++ b/tests/pipelines/cogvideo/test_cogvideox_image2video.py
@@ -43,7 +43,7 @@
enable_full_determinism()
-class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = CogVideoXImageToVideoPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
@@ -343,7 +343,6 @@ def test_fused_qkv_projections(self):
), "Original outputs should match when fused QKV projections are disabled."
-@unittest.skip("The model 'THUDM/CogVideoX-5b-I2V' is not public yet.")
@slow
@require_torch_gpu
class CogVideoXImageToVideoPipelineIntegrationTests(unittest.TestCase):
diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py
index 697244dcb105..bb3bdc273cc4 100644
--- a/tests/pipelines/test_pipeline_utils.py
+++ b/tests/pipelines/test_pipeline_utils.py
@@ -18,7 +18,7 @@
StableDiffusionPipeline,
UNet2DConditionModel,
)
-from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
+from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible
from diffusers.utils.testing_utils import torch_device
@@ -197,6 +197,18 @@ def test_diffusers_is_compatible_only_variants(self):
]
self.assertTrue(is_safetensors_compatible(filenames))
+ def test_diffusers_is_compatible_no_components(self):
+ filenames = [
+ "diffusion_pytorch_model.bin",
+ ]
+ self.assertFalse(is_safetensors_compatible(filenames))
+
+ def test_diffusers_is_compatible_no_components_only_variants(self):
+ filenames = [
+ "diffusion_pytorch_model.fp16.bin",
+ ]
+ self.assertFalse(is_safetensors_compatible(filenames))
+
class ProgressBarTests(unittest.TestCase):
def get_dummy_components_image_generation(self):
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index 8b087db6726e..43b01c40f5bb 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -947,6 +947,27 @@ def test_text_inversion_multi_tokens(self):
emb1[num_tokens + 1].sum().item() == emb2[num_tokens + 1].sum().item() == emb3[num_tokens + 1].sum().item()
)
+ def test_textual_inversion_unload(self):
+ pipe1 = StableDiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+ )
+ pipe1 = pipe1.to(torch_device)
+ orig_tokenizer_size = len(pipe1.tokenizer)
+ orig_emb_size = len(pipe1.text_encoder.get_input_embeddings().weight)
+
+ token = "<*>"
+ ten = torch.ones((32,))
+ pipe1.load_textual_inversion(ten, token=token)
+ pipe1.unload_textual_inversion()
+ pipe1.load_textual_inversion(ten, token=token)
+ pipe1.unload_textual_inversion()
+
+ final_tokenizer_size = len(pipe1.tokenizer)
+ final_emb_size = len(pipe1.text_encoder.get_input_embeddings().weight)
+ # both should be restored to original size
+ assert final_tokenizer_size == orig_tokenizer_size
+ assert final_emb_size == orig_emb_size
+
def test_download_ignore_files(self):
# Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
with tempfile.TemporaryDirectory() as tmpdirname:
diff --git a/tests/quantization/bnb/README.md b/tests/quantization/bnb/README.md
new file mode 100644
index 000000000000..f1585581597d
--- /dev/null
+++ b/tests/quantization/bnb/README.md
@@ -0,0 +1,44 @@
+The tests here are adapted from [`transformers` tests](https://github.com/huggingface/transformers/tree/409fcfdfccde77a14b7cc36972b774cabc371ae1/tests/quantization/bnb).
+
+They were conducted on the `audace` machine, using a single RTX 4090. Below is `nvidia-smi`:
+
+```bash
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 550.90.07 Driver Version: 550.90.07 CUDA Version: 12.4 |
+|-----------------------------------------+------------------------+----------------------+
+| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+| | | MIG M. |
+|=========================================+========================+======================|
+| 0 NVIDIA GeForce RTX 4090 Off | 00000000:01:00.0 Off | Off |
+| 30% 55C P0 61W / 450W | 1MiB / 24564MiB | 2% Default |
+| | | N/A |
++-----------------------------------------+------------------------+----------------------+
+| 1 NVIDIA GeForce RTX 4090 Off | 00000000:13:00.0 Off | Off |
+| 30% 51C P0 60W / 450W | 1MiB / 24564MiB | 0% Default |
+| | | N/A |
++-----------------------------------------+------------------------+----------------------+
+```
+
+`diffusers-cli`:
+
+```bash
+- 🤗 Diffusers version: 0.31.0.dev0
+- Platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.35
+- Running on Google Colab?: No
+- Python version: 3.10.12
+- PyTorch version (GPU?): 2.5.0.dev20240818+cu124 (True)
+- Flax version (CPU?/GPU?/TPU?): not installed (NA)
+- Jax version: not installed
+- JaxLib version: not installed
+- Huggingface_hub version: 0.24.5
+- Transformers version: 4.44.2
+- Accelerate version: 0.34.0.dev0
+- PEFT version: 0.12.0
+- Bitsandbytes version: 0.43.3
+- Safetensors version: 0.4.4
+- xFormers version: not installed
+- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
+NVIDIA GeForce RTX 4090, 24564 MiB
+- Using GPU in script?: Yes
+```
\ No newline at end of file
diff --git a/tests/quantization/bnb/__init__.py b/tests/quantization/bnb/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
new file mode 100644
index 000000000000..7b553434fbe9
--- /dev/null
+++ b/tests/quantization/bnb/test_4bit.py
@@ -0,0 +1,627 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import os
+import tempfile
+import unittest
+
+import numpy as np
+import safetensors.torch
+
+from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
+from diffusers.utils import logging
+from diffusers.utils.testing_utils import (
+ CaptureLogger,
+ is_bitsandbytes_available,
+ is_torch_available,
+ is_transformers_available,
+ load_pt,
+ numpy_cosine_similarity_distance,
+ require_accelerate,
+ require_bitsandbytes_version_greater,
+ require_torch,
+ require_torch_gpu,
+ require_transformers_version_greater,
+ slow,
+ torch_device,
+)
+
+
+def get_some_linear_layer(model):
+ if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
+ return model.transformer_blocks[0].attn.to_q
+ else:
+        raise NotImplementedError("Don't know what layer to retrieve here.")
+
+
+if is_transformers_available():
+ from transformers import T5EncoderModel
+
+if is_torch_available():
+ import torch
+ import torch.nn as nn
+
+ class LoRALayer(nn.Module):
+ """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
+
+ Taken from
+ https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
+ """
+
+ def __init__(self, module: nn.Module, rank: int):
+ super().__init__()
+ self.module = module
+ self.adapter = nn.Sequential(
+ nn.Linear(module.in_features, rank, bias=False),
+ nn.Linear(rank, module.out_features, bias=False),
+ )
+ small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
+ nn.init.normal_(self.adapter[0].weight, std=small_std)
+ nn.init.zeros_(self.adapter[1].weight)
+ self.adapter.to(module.weight.device)
+
+ def forward(self, input, *args, **kwargs):
+ return self.module(input, *args, **kwargs) + self.adapter(input)
+
+
+if is_bitsandbytes_available():
+ import bitsandbytes as bnb
+
+
+@require_bitsandbytes_version_greater("0.43.2")
+@require_accelerate
+@require_torch
+@require_torch_gpu
+@slow
+class Base4bitTests(unittest.TestCase):
+    # We need to test on relatively large models (>1b parameters), otherwise quantization may not work as expected.
+ # Therefore here we use only SD3 to test our module
+ model_name = "stabilityai/stable-diffusion-3-medium-diffusers"
+
+ # This was obtained on audace so the number might slightly change
+ expected_rel_difference = 3.69
+
+ prompt = "a beautiful sunset amidst the mountains."
+ num_inference_steps = 10
+ seed = 0
+
+ def get_dummy_inputs(self):
+ prompt_embeds = load_pt(
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
+ )
+ pooled_prompt_embeds = load_pt(
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
+ )
+ latent_model_input = load_pt(
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
+ )
+
+ input_dict_for_transformer = {
+ "hidden_states": latent_model_input,
+ "encoder_hidden_states": prompt_embeds,
+ "pooled_projections": pooled_prompt_embeds,
+ "timestep": torch.Tensor([1.0]),
+ "return_dict": False,
+ }
+ return input_dict_for_transformer
+
+
+class BnB4BitBasicTests(Base4bitTests):
+ def setUp(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Models
+ self.model_fp16 = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", torch_dtype=torch.float16
+ )
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.float16,
+ )
+ self.model_4bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=nf4_config
+ )
+
+ def tearDown(self):
+ del self.model_fp16
+ del self.model_4bit
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_quantization_num_parameters(self):
+ r"""
+ Test if the number of returned parameters is correct
+ """
+ num_params_4bit = self.model_4bit.num_parameters()
+ num_params_fp16 = self.model_fp16.num_parameters()
+
+ self.assertEqual(num_params_4bit, num_params_fp16)
+
+ def test_quantization_config_json_serialization(self):
+ r"""
+ A simple test to check if the quantization config is correctly serialized and deserialized
+ """
+ config = self.model_4bit.config
+
+ self.assertTrue("quantization_config" in config)
+
+ _ = config["quantization_config"].to_dict()
+ _ = config["quantization_config"].to_diff_dict()
+
+ _ = config["quantization_config"].to_json_string()
+
+ def test_memory_footprint(self):
+ r"""
+ A simple test to check if the model conversion has been done correctly by checking on the
+ memory footprint of the converted model and the class type of the linear layers of the converted models
+ """
+ mem_fp16 = self.model_fp16.get_memory_footprint()
+ mem_4bit = self.model_4bit.get_memory_footprint()
+
+ self.assertAlmostEqual(mem_fp16 / mem_4bit, self.expected_rel_difference, delta=1e-2)
+ linear = get_some_linear_layer(self.model_4bit)
+ self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
+
+ def test_original_dtype(self):
+ r"""
+        A simple test to check if the model successfully stores the original dtype
+ """
+ self.assertTrue("_pre_quantization_dtype" in self.model_4bit.config)
+ self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config)
+ self.assertTrue(self.model_4bit.config["_pre_quantization_dtype"] == torch.float16)
+
+ def test_keep_modules_in_fp32(self):
+ r"""
+        A simple test to check if the modules under `_keep_in_fp32_modules` are kept in fp32.
+        Also ensures that inference works.
+ """
+ fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules
+ SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
+
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.float16,
+ )
+ model = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=nf4_config
+ )
+
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if name in model._keep_in_fp32_modules:
+ self.assertTrue(module.weight.dtype == torch.float32)
+ else:
+ # 4-bit parameters are packed in uint8 variables
+ self.assertTrue(module.weight.dtype == torch.uint8)
+
+ # test if inference works.
+        with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.float16):
+ input_dict_for_transformer = self.get_dummy_inputs()
+ model_inputs = {
+ k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
+ }
+ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+ _ = model(**model_inputs)
+
+ SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules
+
+ def test_linear_are_4bit(self):
+ r"""
+ A simple test to check if the model conversion has been done correctly by checking on the
+ memory footprint of the converted model and the class type of the linear layers of the converted models
+ """
+ self.model_fp16.get_memory_footprint()
+ self.model_4bit.get_memory_footprint()
+
+ for name, module in self.model_4bit.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if name not in ["proj_out"]:
+ # 4-bit parameters are packed in uint8 variables
+ self.assertTrue(module.weight.dtype == torch.uint8)
+
+ def test_config_from_pretrained(self):
+ transformer_4bit = FluxTransformer2DModel.from_pretrained(
+ "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
+ )
+ linear = get_some_linear_layer(transformer_4bit)
+ self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
+ self.assertTrue(hasattr(linear.weight, "quant_state"))
+ self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState)
+
+ def test_device_assignment(self):
+ mem_before = self.model_4bit.get_memory_footprint()
+
+ # Move to CPU
+ self.model_4bit.to("cpu")
+ self.assertEqual(self.model_4bit.device.type, "cpu")
+ self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
+
+ # Move back to CUDA device
+ for device in [0, "cuda", "cuda:0", "call()"]:
+ if device == "call()":
+ self.model_4bit.cuda(0)
+ else:
+ self.model_4bit.to(device)
+ self.assertEqual(self.model_4bit.device, torch.device(0))
+ self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before)
+ self.model_4bit.to("cpu")
+
+ def test_device_and_dtype_assignment(self):
+ r"""
+        Test whether trying to cast (or assign a device to) a model after converting it to 4-bit raises an error.
+        Also checks that other models are cast correctly. Device placement, however, is supported.
+ """
+ with self.assertRaises(ValueError):
+ # Tries with a `dtype`
+ self.model_4bit.to(torch.float16)
+
+ with self.assertRaises(ValueError):
+ # Tries with a `device` and `dtype`
+ self.model_4bit.to(device="cuda:0", dtype=torch.float16)
+
+ with self.assertRaises(ValueError):
+ # Tries with a cast
+ self.model_4bit.float()
+
+ with self.assertRaises(ValueError):
+ # Tries with a cast
+ self.model_4bit.half()
+
+ # This should work
+ self.model_4bit.to("cuda")
+
+ # Test if we did not break anything
+ self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
+ input_dict_for_transformer = self.get_dummy_inputs()
+ model_inputs = {
+ k: v.to(dtype=torch.float32, device=torch_device)
+ for k, v in input_dict_for_transformer.items()
+ if not isinstance(v, bool)
+ }
+ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+ with torch.no_grad():
+ _ = self.model_fp16(**model_inputs)
+
+ # Check this does not throw an error
+ _ = self.model_fp16.to("cpu")
+
+ # Check this does not throw an error
+ _ = self.model_fp16.half()
+
+ # Check this does not throw an error
+ _ = self.model_fp16.float()
+
+ # Check that this does not throw an error
+ _ = self.model_fp16.cuda()
+
+ def test_bnb_4bit_wrong_config(self):
+ r"""
+ Test whether creating a bnb config with unsupported values leads to errors.
+ """
+ with self.assertRaises(ValueError):
+ _ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add")
+
+ def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
+ r"""
+ Test if loading with an incorrect state dict raises an error.
+ """
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ nf4_config = BitsAndBytesConfig(load_in_4bit=True)
+ model_4bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=nf4_config
+ )
+ model_4bit.save_pretrained(tmpdirname)
+ del model_4bit
+
+ with self.assertRaises(ValueError) as err_context:
+ state_dict = safetensors.torch.load_file(
+ os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
+ )
+
+ # corrupt the state dict
+ key_to_target = "context_embedder.weight" # can be other keys too.
+ compatible_param = state_dict[key_to_target]
+ corrupted_param = torch.randn(compatible_param.shape[0] - 1, 1)
+ state_dict[key_to_target] = bnb.nn.Params4bit(corrupted_param, requires_grad=False)
+ safetensors.torch.save_file(
+ state_dict, os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
+ )
+
+ _ = SD3Transformer2DModel.from_pretrained(tmpdirname)
+
+ assert key_to_target in str(err_context.exception)
+
+
+class BnB4BitTrainingTests(Base4bitTests):
+ def setUp(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.float16,
+ )
+ self.model_4bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=nf4_config
+ )
+
+ def test_training(self):
+ # Step 1: freeze all parameters
+ for param in self.model_4bit.parameters():
+ param.requires_grad = False # freeze the model - train adapters later
+ if param.ndim == 1:
+ # cast the small parameters (e.g. layernorm) to fp32 for stability
+ param.data = param.data.to(torch.float32)
+
+ # Step 2: add adapters
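+        # Wrap the q/k/v projections of every attention block (matched by class name) with trainable LoRA adapters.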
+ for _, module in self.model_4bit.named_modules():
+ if "Attention" in repr(type(module)):
+ module.to_k = LoRALayer(module.to_k, rank=4)
+ module.to_q = LoRALayer(module.to_q, rank=4)
+ module.to_v = LoRALayer(module.to_v, rank=4)
+
+ # Step 3: dummy batch
+ input_dict_for_transformer = self.get_dummy_inputs()
+ model_inputs = {
+ k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
+ }
+ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+
+ # Step 4: Check if the gradient is not None
+ with torch.amp.autocast("cuda", dtype=torch.float16):
+ out = self.model_4bit(**model_inputs)[0]
+ out.norm().backward()
+
+ for module in self.model_4bit.modules():
+ if isinstance(module, LoRALayer):
+ self.assertTrue(module.adapter[1].weight.grad is not None)
+ self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
+
+
+@require_transformers_version_greater("4.44.0")
+class SlowBnb4BitTests(Base4bitTests):
+ def setUp(self) -> None:
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.float16,
+ )
+ model_4bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=nf4_config
+ )
+ self.pipeline_4bit = DiffusionPipeline.from_pretrained(
+ self.model_name, transformer=model_4bit, torch_dtype=torch.float16
+ )
+ self.pipeline_4bit.enable_model_cpu_offload()
+
+ def tearDown(self):
+ del self.pipeline_4bit
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_quality(self):
+ output = self.pipeline_4bit(
+ prompt=self.prompt,
+ num_inference_steps=self.num_inference_steps,
+ generator=torch.manual_seed(self.seed),
+ output_type="np",
+ ).images
+
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.1123, 0.1296, 0.1609, 0.1042, 0.1230, 0.1274, 0.0928, 0.1165, 0.1216])
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ print(f"{max_diff=}")
+ self.assertTrue(max_diff < 1e-2)
+
+ def test_generate_quality_dequantize(self):
+ r"""
+        Test that loading the model and dequantizing it produces correct results.
+ """
+ self.pipeline_4bit.transformer.dequantize()
+ output = self.pipeline_4bit(
+ prompt=self.prompt,
+ num_inference_steps=self.num_inference_steps,
+ generator=torch.manual_seed(self.seed),
+ output_type="np",
+ ).images
+
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.1216, 0.1387, 0.1584, 0.1152, 0.1318, 0.1282, 0.1062, 0.1226, 0.1228])
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ self.assertTrue(max_diff < 1e-3)
+
+        # Since `pipeline_4bit.transformer` was offloaded to the CPU (a result of `enable_model_cpu_offload()`),
+        # check the following.
+ self.assertTrue(self.pipeline_4bit.transformer.device.type == "cpu")
+ # calling it again shouldn't be a problem
+ _ = self.pipeline_4bit(
+ prompt=self.prompt,
+ num_inference_steps=2,
+ generator=torch.manual_seed(self.seed),
+ output_type="np",
+ ).images
+
+ def test_moving_to_cpu_throws_warning(self):
+ nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+ bnb_4bit_compute_dtype=torch.float16,
+ )
+ model_4bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=nf4_config
+ )
+
+ logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
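+        # 30 corresponds to logging.WARNING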
+ logger.setLevel(30)
+ with CaptureLogger(logger) as cap_logger:
+            # `model.dtype` reports torch.float16 here because the SD3 transformer's first layer is a
+            # conv layer (which is not quantized), so the fp16 CPU warning below is expected.
+ _ = DiffusionPipeline.from_pretrained(
+ self.model_name, transformer=model_4bit, torch_dtype=torch.float16
+ ).to("cpu")
+
+ assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out
+
+
+@require_transformers_version_greater("4.44.0")
+class SlowBnb4BitFluxTests(Base4bitTests):
+ def setUp(self) -> None:
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
+ t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+ transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
+ self.pipeline_4bit = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ text_encoder_2=t5_4bit,
+ transformer=transformer_4bit,
+ torch_dtype=torch.float16,
+ )
+ self.pipeline_4bit.enable_model_cpu_offload()
+
+ def tearDown(self):
+ del self.pipeline_4bit
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_quality(self):
+ # keep the resolution and max tokens to a lower number for faster execution.
+ output = self.pipeline_4bit(
+ prompt=self.prompt,
+ num_inference_steps=self.num_inference_steps,
+ generator=torch.manual_seed(self.seed),
+ height=256,
+ width=256,
+ max_sequence_length=64,
+ output_type="np",
+ ).images
+
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.0583, 0.0586, 0.0632, 0.0815, 0.0813, 0.0947, 0.1040, 0.1145, 0.1265])
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ self.assertTrue(max_diff < 1e-3)
+
+
+@slow
+class BaseBnb4BitSerializationTests(Base4bitTests):
+ def tearDown(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True):
+ r"""
+        Test whether it is possible to serialize a model in 4-bit. Uses the most typical params as defaults.
+        See the ExtendedSerializationTest class for more parameter combinations.
+ """
+
+ self.quantization_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type=quant_type,
+ bnb_4bit_use_double_quant=double_quant,
+ bnb_4bit_compute_dtype=torch.bfloat16,
+ )
+ model_0 = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=self.quantization_config
+ )
+ self.assertTrue("_pre_quantization_dtype" in model_0.config)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model_0.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
+
+ config = SD3Transformer2DModel.load_config(tmpdirname)
+ self.assertTrue("quantization_config" in config)
+ self.assertTrue("_pre_quantization_dtype" not in config)
+
+ model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname)
+
+ # checking quantized linear module weight
+ linear = get_some_linear_layer(model_1)
+ self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
+ self.assertTrue(hasattr(linear.weight, "quant_state"))
+ self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState)
+
+        # checking memory footprint
+ self.assertAlmostEqual(model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2)
+
+ # Matching all parameters and their quant_state items:
+ d0 = dict(model_0.named_parameters())
+ d1 = dict(model_1.named_parameters())
+ self.assertTrue(d0.keys() == d1.keys())
+
+ for k in d0.keys():
+ self.assertTrue(d0[k].shape == d1[k].shape)
+ self.assertTrue(d0[k].device.type == d1[k].device.type)
+ self.assertTrue(d0[k].device == d1[k].device)
+ self.assertTrue(d0[k].dtype == d1[k].dtype)
+ self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device)))
+
+ if isinstance(d0[k], bnb.nn.modules.Params4bit):
+ for v0, v1 in zip(
+ d0[k].quant_state.as_dict().values(),
+ d1[k].quant_state.as_dict().values(),
+ ):
+ if isinstance(v0, torch.Tensor):
+ self.assertTrue(torch.equal(v0, v1.to(v0.device)))
+ else:
+ self.assertTrue(v0 == v1)
+
+ # comparing forward() outputs
+ dummy_inputs = self.get_dummy_inputs()
+ inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)}
+ inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs})
+ out_0 = model_0(**inputs)[0]
+ out_1 = model_1(**inputs)[0]
+ self.assertTrue(torch.equal(out_0, out_1))
+
+
+class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
+ """
+    Tests more combinations of parameters.
+ """
+
+ def test_nf4_single_unsafe(self):
+ self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=False)
+
+ def test_nf4_single_safe(self):
+ self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=True)
+
+ def test_nf4_double_unsafe(self):
+ self.test_serialization(quant_type="nf4", double_quant=True, safe_serialization=False)
+
+    # nf4 double-quant safetensors serialization is tested by the test_serialization() method of the parent class
+
+ def test_fp4_single_unsafe(self):
+ self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=False)
+
+ def test_fp4_single_safe(self):
+ self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=True)
+
+ def test_fp4_double_unsafe(self):
+ self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=False)
+
+ def test_fp4_double_safe(self):
+ self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
new file mode 100644
index 000000000000..ba2402461c87
--- /dev/null
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -0,0 +1,552 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import tempfile
+import unittest
+
+import numpy as np
+
+from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
+from diffusers.utils.testing_utils import (
+ CaptureLogger,
+ is_bitsandbytes_available,
+ is_torch_available,
+ is_transformers_available,
+ load_pt,
+ numpy_cosine_similarity_distance,
+ require_accelerate,
+ require_bitsandbytes_version_greater,
+ require_torch,
+ require_torch_gpu,
+ require_transformers_version_greater,
+ slow,
+ torch_device,
+)
+
+
+def get_some_linear_layer(model):
+ if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
+ return model.transformer_blocks[0].attn.to_q
+ else:
+        raise NotImplementedError("Don't know what layer to retrieve here.")
+
+
+if is_transformers_available():
+ from transformers import T5EncoderModel
+
+if is_torch_available():
+ import torch
+ import torch.nn as nn
+
+ class LoRALayer(nn.Module):
+ """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only
+
+ Taken from
+ https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77
+ """
+
+ def __init__(self, module: nn.Module, rank: int):
+ super().__init__()
+ self.module = module
+ self.adapter = nn.Sequential(
+ nn.Linear(module.in_features, rank, bias=False),
+ nn.Linear(rank, module.out_features, bias=False),
+ )
+ small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
+ nn.init.normal_(self.adapter[0].weight, std=small_std)
+ nn.init.zeros_(self.adapter[1].weight)
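+            # The up-projection starts at zero, so the adapter output is initially zero and the wrapped
+            # module's behavior is unchanged until training updates the adapter weights.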
+ self.adapter.to(module.weight.device)
+
+ def forward(self, input, *args, **kwargs):
+ return self.module(input, *args, **kwargs) + self.adapter(input)
+
+
+if is_bitsandbytes_available():
+ import bitsandbytes as bnb
+
+
+@require_bitsandbytes_version_greater("0.43.2")
+@require_accelerate
+@require_torch
+@require_torch_gpu
+@slow
+class Base8bitTests(unittest.TestCase):
+    # We need to test on relatively large models (i.e. >1B parameters), otherwise the quantization may not work as expected.
+    # Therefore, we only use SD3 here to test our module.
+ model_name = "stabilityai/stable-diffusion-3-medium-diffusers"
+
+    # This was obtained on audace, so the number might change slightly.
+ expected_rel_difference = 1.94
+
+ prompt = "a beautiful sunset amidst the mountains."
+ num_inference_steps = 10
+ seed = 0
+
+ def get_dummy_inputs(self):
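+        # Precomputed prompt embeddings and latents (hosted as test artifacts on the Hub) let these tests
+        # run the transformer in isolation, without loading the text encoders or the VAE.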
+ prompt_embeds = load_pt(
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
+ )
+ pooled_prompt_embeds = load_pt(
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
+ )
+ latent_model_input = load_pt(
+ "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
+ )
+
+ input_dict_for_transformer = {
+ "hidden_states": latent_model_input,
+ "encoder_hidden_states": prompt_embeds,
+ "pooled_projections": pooled_prompt_embeds,
+ "timestep": torch.Tensor([1.0]),
+ "return_dict": False,
+ }
+ return input_dict_for_transformer
+
+
+class BnB8bitBasicTests(Base8bitTests):
+ def setUp(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # Models
+ self.model_fp16 = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", torch_dtype=torch.float16
+ )
+ mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+ self.model_8bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+ )
+
+ def tearDown(self):
+ del self.model_fp16
+ del self.model_8bit
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_quantization_num_parameters(self):
+ r"""
+ Test if the number of returned parameters is correct
+ """
+ num_params_8bit = self.model_8bit.num_parameters()
+ num_params_fp16 = self.model_fp16.num_parameters()
+
+ self.assertEqual(num_params_8bit, num_params_fp16)
+
+ def test_quantization_config_json_serialization(self):
+ r"""
+ A simple test to check if the quantization config is correctly serialized and deserialized
+ """
+ config = self.model_8bit.config
+
+ self.assertTrue("quantization_config" in config)
+
+ _ = config["quantization_config"].to_dict()
+ _ = config["quantization_config"].to_diff_dict()
+
+ _ = config["quantization_config"].to_json_string()
+
+ def test_memory_footprint(self):
+ r"""
+        A simple test to check if the model conversion has been done correctly by checking the
+        memory footprint of the converted model and the class type of its linear layers.
+ """
+ mem_fp16 = self.model_fp16.get_memory_footprint()
+ mem_8bit = self.model_8bit.get_memory_footprint()
+
+ self.assertAlmostEqual(mem_fp16 / mem_8bit, self.expected_rel_difference, delta=1e-2)
+ linear = get_some_linear_layer(self.model_8bit)
+ self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
+
+ def test_original_dtype(self):
+ r"""
+        A simple test to check if the model successfully stores the original dtype
+ """
+ self.assertTrue("_pre_quantization_dtype" in self.model_8bit.config)
+ self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config)
+ self.assertTrue(self.model_8bit.config["_pre_quantization_dtype"] == torch.float16)
+
+ def test_keep_modules_in_fp32(self):
+ r"""
+        A simple test to check that the modules listed in `_keep_in_fp32_modules` are kept in fp32.
+        Also ensures that inference works.
+ """
+ fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules
+ SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"]
+
+ mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+ model = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+ )
+
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if name in model._keep_in_fp32_modules:
+ self.assertTrue(module.weight.dtype == torch.float32)
+ else:
+ # 8-bit parameters are packed in int8 variables
+ self.assertTrue(module.weight.dtype == torch.int8)
+
+        # test if inference works (use a single `with` statement so both the no_grad and autocast contexts are entered).
+        with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.float16):
+ input_dict_for_transformer = self.get_dummy_inputs()
+ model_inputs = {
+ k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
+ }
+ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+ _ = model(**model_inputs)
+
+ SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules
+
+ def test_linear_are_8bit(self):
+ r"""
+        A simple test to check if the model conversion has been done correctly by checking the
+        memory footprint of the converted model and the class type of its linear layers.
+ """
+ self.model_fp16.get_memory_footprint()
+ self.model_8bit.get_memory_footprint()
+
+ for name, module in self.model_8bit.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if name not in ["proj_out"]:
+ # 8-bit parameters are packed in int8 variables
+ self.assertTrue(module.weight.dtype == torch.int8)
+
+ def test_llm_skip(self):
+ r"""
+ A simple test to check if `llm_int8_skip_modules` works as expected
+ """
+ config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out"])
+ model_8bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=config
+ )
+ linear = get_some_linear_layer(model_8bit)
+ self.assertTrue(linear.weight.dtype == torch.int8)
+ self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt))
+
+ self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear))
+ self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8)
+
+ def test_config_from_pretrained(self):
+ transformer_8bit = FluxTransformer2DModel.from_pretrained(
+ "hf-internal-testing/flux.1-dev-int8-pkg", subfolder="transformer"
+ )
+ linear = get_some_linear_layer(transformer_8bit)
+ self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
+ self.assertTrue(hasattr(linear.weight, "SCB"))
+
+ def test_device_and_dtype_assignment(self):
+ r"""
+        Test whether trying to cast (or assign a device to) a model after converting it to 8-bit throws an error.
+        Also checks that other models are cast correctly.
+ """
+ with self.assertRaises(ValueError):
+ # Tries with `str`
+ self.model_8bit.to("cpu")
+
+ with self.assertRaises(ValueError):
+            # Tries with a `dtype`
+ self.model_8bit.to(torch.float16)
+
+ with self.assertRaises(ValueError):
+ # Tries with a `device`
+ self.model_8bit.to(torch.device("cuda:0"))
+
+ with self.assertRaises(ValueError):
+            # Tries with a cast
+ self.model_8bit.float()
+
+ with self.assertRaises(ValueError):
+            # Tries with a cast
+ self.model_8bit.half()
+
+ # Test if we did not break anything
+ self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
+ input_dict_for_transformer = self.get_dummy_inputs()
+ model_inputs = {
+ k: v.to(dtype=torch.float32, device=torch_device)
+ for k, v in input_dict_for_transformer.items()
+ if not isinstance(v, bool)
+ }
+ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+ with torch.no_grad():
+ _ = self.model_fp16(**model_inputs)
+
+ # Check this does not throw an error
+ _ = self.model_fp16.to("cpu")
+
+ # Check this does not throw an error
+ _ = self.model_fp16.half()
+
+ # Check this does not throw an error
+ _ = self.model_fp16.float()
+
+ # Check that this does not throw an error
+ _ = self.model_fp16.cuda()
+
+
+class BnB8bitTrainingTests(Base8bitTests):
+ def setUp(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+ self.model_8bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+ )
+
+ def test_training(self):
+ # Step 1: freeze all parameters
+ for param in self.model_8bit.parameters():
+ param.requires_grad = False # freeze the model - train adapters later
+ if param.ndim == 1:
+ # cast the small parameters (e.g. layernorm) to fp32 for stability
+ param.data = param.data.to(torch.float32)
+
+ # Step 2: add adapters
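+        # Wrap the q/k/v projections of every attention block (matched by class name) with trainable LoRA adapters.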
+ for _, module in self.model_8bit.named_modules():
+ if "Attention" in repr(type(module)):
+ module.to_k = LoRALayer(module.to_k, rank=4)
+ module.to_q = LoRALayer(module.to_q, rank=4)
+ module.to_v = LoRALayer(module.to_v, rank=4)
+
+ # Step 3: dummy batch
+ input_dict_for_transformer = self.get_dummy_inputs()
+ model_inputs = {
+ k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
+ }
+ model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+
+ # Step 4: Check if the gradient is not None
+ with torch.amp.autocast("cuda", dtype=torch.float16):
+ out = self.model_8bit(**model_inputs)[0]
+ out.norm().backward()
+
+ for module in self.model_8bit.modules():
+ if isinstance(module, LoRALayer):
+ self.assertTrue(module.adapter[1].weight.grad is not None)
+ self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
+
+
+@require_transformers_version_greater("4.44.0")
+class SlowBnb8bitTests(Base8bitTests):
+ def setUp(self) -> None:
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+ model_8bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+ )
+ self.pipeline_8bit = DiffusionPipeline.from_pretrained(
+ self.model_name, transformer=model_8bit, torch_dtype=torch.float16
+ )
+ self.pipeline_8bit.enable_model_cpu_offload()
+
+ def tearDown(self):
+ del self.pipeline_8bit
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_quality(self):
+ output = self.pipeline_8bit(
+ prompt=self.prompt,
+ num_inference_steps=self.num_inference_steps,
+ generator=torch.manual_seed(self.seed),
+ output_type="np",
+ ).images
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.0149, 0.0322, 0.0073, 0.0134, 0.0332, 0.011, 0.002, 0.0232, 0.0193])
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ self.assertTrue(max_diff < 1e-2)
+
+ def test_model_cpu_offload_raises_warning(self):
+ model_8bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
+ )
+ pipeline_8bit = DiffusionPipeline.from_pretrained(
+ self.model_name, transformer=model_8bit, torch_dtype=torch.float16
+ )
+ logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
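+        # 30 corresponds to logging.WARNING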
+ logger.setLevel(30)
+
+ with CaptureLogger(logger) as cap_logger:
+ pipeline_8bit.enable_model_cpu_offload()
+
+ assert "has been loaded in `bitsandbytes` 8bit" in cap_logger.out
+
+ def test_moving_to_cpu_throws_warning(self):
+ model_8bit = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
+ )
+ logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
+ logger.setLevel(30)
+
+ with CaptureLogger(logger) as cap_logger:
+            # `model.dtype` reports torch.float16 here because the SD3 transformer's first layer is a
+            # conv layer (which is not quantized), so the fp16 CPU warning below is expected.
+ _ = DiffusionPipeline.from_pretrained(
+ self.model_name, transformer=model_8bit, torch_dtype=torch.float16
+ ).to("cpu")
+
+ assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out
+
+ def test_generate_quality_dequantize(self):
+ r"""
+        Test that loading the model and dequantizing it produces correct results.
+ """
+ self.pipeline_8bit.transformer.dequantize()
+ output = self.pipeline_8bit(
+ prompt=self.prompt,
+ num_inference_steps=self.num_inference_steps,
+ generator=torch.manual_seed(self.seed),
+ output_type="np",
+ ).images
+
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.0266, 0.0264, 0.0271, 0.0110, 0.0310, 0.0098, 0.0078, 0.0256, 0.0208])
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ self.assertTrue(max_diff < 1e-2)
+
+ # 8bit models cannot be offloaded to CPU.
+ self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda")
+ # calling it again shouldn't be a problem
+ _ = self.pipeline_8bit(
+ prompt=self.prompt,
+ num_inference_steps=2,
+ generator=torch.manual_seed(self.seed),
+ output_type="np",
+ ).images
+
+
+@require_transformers_version_greater("4.44.0")
+class SlowBnb8bitFluxTests(Base8bitTests):
+ def setUp(self) -> None:
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ model_id = "hf-internal-testing/flux.1-dev-int8-pkg"
+ t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
+ transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
+ self.pipeline_8bit = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ text_encoder_2=t5_8bit,
+ transformer=transformer_8bit,
+ torch_dtype=torch.float16,
+ )
+ self.pipeline_8bit.enable_model_cpu_offload()
+
+ def tearDown(self):
+ del self.pipeline_8bit
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_quality(self):
+ # keep the resolution and max tokens to a lower number for faster execution.
+ output = self.pipeline_8bit(
+ prompt=self.prompt,
+ num_inference_steps=self.num_inference_steps,
+ generator=torch.manual_seed(self.seed),
+ height=256,
+ width=256,
+ max_sequence_length=64,
+ output_type="np",
+ ).images
+ out_slice = output[0, -3:, -3:, -1].flatten()
+ expected_slice = np.array([0.0574, 0.0554, 0.0581, 0.0686, 0.0676, 0.0759, 0.0757, 0.0803, 0.0930])
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+ self.assertTrue(max_diff < 1e-3)
+
+
+@slow
+class BaseBnb8bitSerializationTests(Base8bitTests):
+ def setUp(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ quantization_config = BitsAndBytesConfig(
+ load_in_8bit=True,
+ )
+ self.model_0 = SD3Transformer2DModel.from_pretrained(
+ self.model_name, subfolder="transformer", quantization_config=quantization_config
+ )
+
+ def tearDown(self):
+ del self.model_0
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_serialization(self):
+ r"""
+        Test whether it is possible to serialize a model in 8-bit. Uses the most typical params as defaults.
+ """
+ self.assertTrue("_pre_quantization_dtype" in self.model_0.config)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ self.model_0.save_pretrained(tmpdirname)
+
+ config = SD3Transformer2DModel.load_config(tmpdirname)
+ self.assertTrue("quantization_config" in config)
+ self.assertTrue("_pre_quantization_dtype" not in config)
+
+ model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname)
+
+ # checking quantized linear module weight
+ linear = get_some_linear_layer(model_1)
+ self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
+ self.assertTrue(hasattr(linear.weight, "SCB"))
+
+        # checking memory footprint
+ self.assertAlmostEqual(self.model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2)
+
+ # Matching all parameters and their quant_state items:
+ d0 = dict(self.model_0.named_parameters())
+ d1 = dict(model_1.named_parameters())
+ self.assertTrue(d0.keys() == d1.keys())
+
+ # comparing forward() outputs
+ dummy_inputs = self.get_dummy_inputs()
+ inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)}
+ inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs})
+ out_0 = self.model_0(**inputs)[0]
+ out_1 = model_1(**inputs)[0]
+ self.assertTrue(torch.equal(out_0, out_1))
+
+ def test_serialization_sharded(self):
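+        # `max_shard_size="200MB"` forces the checkpoint to be split into multiple shards so that loading
+        # a sharded quantized checkpoint is covered as well.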
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ self.model_0.save_pretrained(tmpdirname, max_shard_size="200MB")
+
+ config = SD3Transformer2DModel.load_config(tmpdirname)
+ self.assertTrue("quantization_config" in config)
+ self.assertTrue("_pre_quantization_dtype" not in config)
+
+ model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname)
+
+ # checking quantized linear module weight
+ linear = get_some_linear_layer(model_1)
+ self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
+ self.assertTrue(hasattr(linear.weight, "SCB"))
+
+ # comparing forward() outputs
+ dummy_inputs = self.get_dummy_inputs()
+ inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)}
+ inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs})
+ out_0 = self.model_0(**inputs)[0]
+ out_1 = model_1(**inputs)[0]
+ self.assertTrue(torch.equal(out_0, out_1))