diff --git a/.github/workflows/pr_test_peft_backend.yml b/.github/workflows/pr_test_peft_backend.yml index 26b4b5a3deec..190e5d26e6f3 100644 --- a/.github/workflows/pr_test_peft_backend.yml +++ b/.github/workflows/pr_test_peft_backend.yml @@ -92,12 +92,14 @@ jobs: run: | python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" python -m uv pip install -e [quality,test] + # TODO (sayakpaul, DN6): revisit `--no-deps` if [ "${{ matrix.lib-versions }}" == "main" ]; then - python -m pip install -U peft@git+https://github.com/huggingface/peft.git - python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git - pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git + python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps + python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps + pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps else - python -m uv pip install -U peft transformers accelerate + python -m uv pip install -U peft --no-deps + python -m uv pip install -U transformers accelerate --no-deps fi - name: Environment diff --git a/docker/diffusers-onnxruntime-cuda/Dockerfile b/docker/diffusers-onnxruntime-cuda/Dockerfile index bd1d871033c9..6124172e109e 100644 --- a/docker/diffusers-onnxruntime-cuda/Dockerfile +++ b/docker/diffusers-onnxruntime-cuda/Dockerfile @@ -28,7 +28,7 @@ ENV PATH="/opt/venv/bin:$PATH" # pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \ python3.10 -m uv pip install --no-cache-dir \ - torch \ + "torch<2.5.0" \ torchvision \ torchaudio \ "onnxruntime-gpu>=1.13.1" \ diff --git a/docker/diffusers-pytorch-compile-cuda/Dockerfile b/docker/diffusers-pytorch-compile-cuda/Dockerfile index cb4a9c0f9896..9d7578f5a4dc 100644 --- a/docker/diffusers-pytorch-compile-cuda/Dockerfile +++ b/docker/diffusers-pytorch-compile-cuda/Dockerfile @@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH" # pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \ python3.10 -m uv pip install --no-cache-dir \ - torch \ + "torch<2.5.0" \ torchvision \ torchaudio \ invisible_watermark && \ diff --git a/docker/diffusers-pytorch-cpu/Dockerfile b/docker/diffusers-pytorch-cpu/Dockerfile index 8d98c52598d2..1b39e58ca273 100644 --- a/docker/diffusers-pytorch-cpu/Dockerfile +++ b/docker/diffusers-pytorch-cpu/Dockerfile @@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH" # pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \ python3.10 -m uv pip install --no-cache-dir \ - torch \ + "torch<2.5.0" \ torchvision \ torchaudio \ invisible_watermark \ diff --git a/docker/diffusers-pytorch-cuda/Dockerfile b/docker/diffusers-pytorch-cuda/Dockerfile index 695f5ed08dc5..7317ef642aa5 100644 --- a/docker/diffusers-pytorch-cuda/Dockerfile +++ b/docker/diffusers-pytorch-cuda/Dockerfile @@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH" # pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \ python3.10 -m uv pip install --no-cache-dir 
\ - torch \ + "torch<2.5.0" \ torchvision \ torchaudio \ invisible_watermark && \ diff --git a/docker/diffusers-pytorch-xformers-cuda/Dockerfile b/docker/diffusers-pytorch-xformers-cuda/Dockerfile index 1693eb293024..356445a6d173 100644 --- a/docker/diffusers-pytorch-xformers-cuda/Dockerfile +++ b/docker/diffusers-pytorch-xformers-cuda/Dockerfile @@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH" # pre-install the heavy dependencies (these can later be overridden by the deps from setup.py) RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \ python3.10 -m pip install --no-cache-dir \ - torch \ + "torch<2.5.0" \ torchvision \ torchaudio \ invisible_watermark && \ diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index e218e9878599..58218c0272bd 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -150,6 +150,12 @@ title: Reinforcement learning training with DDPO title: Methods title: Training +- sections: + - local: quantization/overview + title: Getting Started + - local: quantization/bitsandbytes + title: bitsandbytes + title: Quantization Methods - sections: - local: optimization/fp16 title: Speed up inference @@ -209,6 +215,8 @@ title: Logging - local: api/outputs title: Outputs + - local: api/quantization + title: Quantization title: Main Classes - isExpanded: false sections: diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md index 4cde7a111ae6..f0f4fd37e6d5 100644 --- a/docs/source/en/api/pipelines/cogvideox.md +++ b/docs/source/en/api/pipelines/cogvideox.md @@ -36,6 +36,10 @@ There are two models available that can be used with the text-to-video and video There is one model available that can be used with the image-to-video CogVideoX pipeline: - [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`. +There are two models that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team): +- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `bf16`. +- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `bf16`. + ## Inference Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency. @@ -118,6 +122,12 @@ It is also worth noting that torchao quantization is fully compatible with [torc - all - __call__ +## CogVideoXFunControlPipeline + +[[autodoc]] CogVideoXFunControlPipeline + - all + - __call__ + ## CogVideoXPipelineOutput [[autodoc]] pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput diff --git a/docs/source/en/api/pipelines/controlnet_flux.md b/docs/source/en/api/pipelines/controlnet_flux.md index f63885b4d42c..82454ae5e930 100644 --- a/docs/source/en/api/pipelines/controlnet_flux.md +++ b/docs/source/en/api/pipelines/controlnet_flux.md @@ -1,4 +1,4 @@ - + +# Quantization + +Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index). 
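As a quick orientation before the class references below, here is a minimal sketch of the workflow documented in the bitsandbytes guide added in this same change: build a [`BitsAndBytesConfig`] and pass it to `from_pretrained` to load a transformer in 4-bit.

```py
import torch
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

# 4-bit NF4 quantization with bf16 compute, mirroring the bitsandbytes guide
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
)
```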
+ +Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class. + + + +Learn how to quantize models in the [Quantization](../quantization/overview) guide. + + + + +## BitsAndBytesConfig + +[[autodoc]] BitsAndBytesConfig + +## DiffusersQuantizer + +[[autodoc]] quantizers.base.DiffusersQuantizer diff --git a/docs/source/en/quantization/bitsandbytes.md b/docs/source/en/quantization/bitsandbytes.md new file mode 100644 index 000000000000..118511b75d50 --- /dev/null +++ b/docs/source/en/quantization/bitsandbytes.md @@ -0,0 +1,260 @@ + + +# bitsandbytes + +[bitsandbytes](https://huggingface.co/docs/bitsandbytes/index) is the easiest option for quantizing a model to 8 and 4-bit. 8-bit quantization multiplies outliers in fp16 with non-outliers in int8, converts the non-outlier values back to fp16, and then adds them together to return the weights in fp16. This reduces the degradative effect outlier values have on a model's performance. + +4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs. + + +To use bitsandbytes, make sure you have the following libraries installed: + +```bash +pip install diffusers transformers accelerate bitsandbytes -U +``` + +Now you can quantize a model by passing a [`BitsAndBytesConfig`] to [`~ModelMixin.from_pretrained`]. This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers. + + + + +Quantizing a model in 8-bit halves the memory-usage: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) + +model_8bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config +) +``` + +By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_8bit=True) + +model_8bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.float32 +) +model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype +``` + +Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`]. + + + + +Quantizing a model in 4-bit reduces your memory-usage by 4x: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True) + +model_4bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config +) +``` + +By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. 
You can change the data type of these modules with the `torch_dtype` parameter if you want: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True) + +model_4bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, + torch_dtype=torch.float32 +) +model_4bit.transformer_blocks.layers[-1].norm2.weight.dtype +``` + +Call [`~ModelMixin.push_to_hub`] after loading it in 4-bit precision. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`]. + + + + + + +Training with 8-bit and 4-bit weights are only supported for training *extra* parameters. + + + +Check your memory footprint with the `get_memory_footprint` method: + +```py +print(model.get_memory_footprint()) +``` + +Quantized models can be loaded from the [`~ModelMixin.from_pretrained`] method without needing to specify the `quantization_config` parameters: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True) + +model_4bit = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer" +) +``` + +## 8-bit (LLM.int8() algorithm) + + + +Learn more about the details of 8-bit quantization in this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration)! + + + +This section explores some of the specific features of 8-bit models, such as outlier thresholds and skipping module conversion. + +### Outlier threshold + +An "outlier" is a hidden state value greater than a certain threshold, and these values are computed in fp16. While the values are usually normally distributed ([-3.5, 3.5]), this distribution can be very different for large models ([-60, 6] or [6, 60]). 8-bit quantization works well for values ~5, but beyond that, there is a significant performance penalty. A good default threshold value is 6, but a lower threshold may be needed for more unstable models (small models or finetuning). + +To find the best threshold for your model, we recommend experimenting with the `llm_int8_threshold` parameter in [`BitsAndBytesConfig`]: + +```py +from diffusers import FluxTransformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig( + load_in_8bit=True, llm_int8_threshold=10, +) + +model_8bit = FluxTransformer2DModel.from_pretrained( + "black-forest-labs/FLUX.1-dev", + subfolder="transformer", + quantization_config=quantization_config, +) +``` + +### Skip module conversion + +For some models, you don't need to quantize every module to 8-bit which can actually cause instability. For example, for diffusion models like [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), the `proj_out` module can be skipped using the `llm_int8_skip_modules` parameter in [`BitsAndBytesConfig`]: + +```py +from diffusers import SD3Transformer2DModel, BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig( + load_in_8bit=True, llm_int8_skip_modules=["proj_out"], +) + +model_8bit = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=quantization_config, +) +``` + + +## 4-bit (QLoRA algorithm) + + + +Learn more about its details in this [blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes). 
+ + + +This section explores some of the specific features of 4-bit models, such as changing the compute data type, using the Normal Float 4 (NF4) data type, and using nested quantization. + + +### Compute data type + +To speedup computation, you can change the data type from float32 (the default value) to bf16 using the `bnb_4bit_compute_dtype` parameter in [`BitsAndBytesConfig`]: + +```py +import torch +from diffusers import BitsAndBytesConfig + +quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16) +``` + +### Normal Float 4 (NF4) + +NF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the `bnb_4bit_quant_type` parameter in the [`BitsAndBytesConfig`]: + +```py +from diffusers import BitsAndBytesConfig + +nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", +) + +model_nf4 = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=nf4_config, +) +``` + +For inference, the `bnb_4bit_quant_type` does not have a huge impact on performance. However, to remain consistent with the model weights, you should use the `bnb_4bit_compute_dtype` and `torch_dtype` values. + +### Nested quantization + +Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter. + +```py +from diffusers import BitsAndBytesConfig + +double_quant_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, +) + +double_quant_model = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=double_quant_config, +) +``` + +## Dequantizing `bitsandbytes` models + +Once quantized, you can dequantize the model to the original precision but this might result in a small quality loss of the model. Make sure you have enough GPU RAM to fit the dequantized model. + +```python +from diffusers import BitsAndBytesConfig + +double_quant_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, +) + +double_quant_model = SD3Transformer2DModel.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + subfolder="transformer", + quantization_config=double_quant_config, +) +model.dequantize() +``` + +## Resources + +* [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4) +* [Training](https://gist.github.com/sayakpaul/05afd428bc089b47af7c016e42004527) \ No newline at end of file diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md new file mode 100644 index 000000000000..d8adbc85a259 --- /dev/null +++ b/docs/source/en/quantization/overview.md @@ -0,0 +1,35 @@ + + +# Quantization + +Quantization techniques focus on representing data with less information while also trying to not lose too much accuracy. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size which makes it easier to store and reduces memory-usage. 
Lower precision can also speedup inference because it takes less time to perform calculations with fewer bits. + + + +Interested in adding a new quantization method to Transformers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method. + + + + + +If you are new to the quantization field, we recommend you to check out these beginner-friendly courses about quantization in collaboration with DeepLearning.AI: + +* [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/) +* [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth/) + + + +## When to use what? + +This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. \ No newline at end of file diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md new file mode 100644 index 000000000000..8817431bede5 --- /dev/null +++ b/examples/advanced_diffusion_training/README_flux.md @@ -0,0 +1,351 @@ +# Advanced diffusion training examples + +## Train Dreambooth LoRA with Flux.1 Dev +> [!TIP] +> 💡 This example follows some of the techniques and recommended practices covered in the community derived guide we made for SDXL training: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). +> As many of these are architecture agnostic & generally relevant to fine-tuning of diffusion models we suggest to take a look 🤗 + +[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like flux, stable diffusion given just a few(3~5) images of a subject. + +LoRA - Low-Rank Adaption of Large Language Models, was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen* +In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages: +- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114) +- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable. +- LoRA attention layers allow to control to which extent the model is adapted towards new training images via a `scale` parameter. +[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in +the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. 
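To make the rank-decomposition idea concrete, below is an illustrative PyTorch sketch of a LoRA-style linear layer. It is only a conceptual example: the training script itself applies adapters through the PEFT library's `LoraConfig` rather than a hand-rolled module like this.

```python
import torch.nn as nn


class LoRALinear(nn.Module):
    """Frozen base layer plus a trainable low-rank update: y = Wx + scale * B(A(x))."""

    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)  # pretrained weights stay frozen
        self.lora_A = nn.Linear(base.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)  # adapter starts as a no-op
        self.scale = scale

    def forward(self, x):
        return self.base(x) + self.scale * self.lora_B(self.lora_A(x))
```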
+ +The `train_dreambooth_lora_flux_advanced.py` script shows how to implement dreambooth-LoRA, combining the training process shown in `train_dreambooth_lora_flux.py`, with +advanced features and techniques, inspired and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl), +[ostris](https://x.com/ostrisai):[ai-toolkit](https://github.com/ostris/ai-toolkit), [bghira](https://github.com/bghira):[SimpleTuner](https://github.com/bghira/SimpleTuner), [Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️ + +> [!NOTE] +> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳 +> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora) + +## Running locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd in the `examples/advanced_diffusion_training` folder and run +```bash +pip install -r requirements.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Or if your environment doesn't support an interactive shell e.g. a notebook + +```python +from accelerate.utils import write_basic_config +write_basic_config() +``` + +When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. +Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. + +### Target Modules +When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them. +More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore +applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string +the exact modules for LoRA training. 
Here are some examples of target modules you can provide: +- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"` +- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"` +- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"` +> [!NOTE] +> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string: +> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k` +> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k` +> [!NOTE] +> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights. + +### Pivotal Tuning (and more) +**Training with text encoder(s)** + +Alongside the Transformer, LoRA fine-tuning of the text encoders is also supported. In addition to the text encoder optimization +available with `train_dreambooth_lora_flux_advanced.py`, in the advanced script **pivotal tuning** is also supported. +[pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning - +we insert new tokens into the text encoders of the model, instead of reusing existing ones. +We then optimize the newly-inserted token embeddings to represent the new concept. + +To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`). +Please keep the following points in mind: + +* Flux uses two text encoders - [CLIP](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder) & [T5](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#diffusers.FluxPipeline.text_encoder_2) , by default `--train_text_encoder_ti` performs pivotal tuning for the **CLIP** encoder only. +To activate pivotal tuning for both encoders, add the flag `--enable_t5_ti`. +* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory. +* **pure textual inversion** - to support the full range from pivotal tuning to textual inversion we introduce `--train_transformer_frac` which controls the amount of epochs the transformer LoRA layers are trained. By default, `--train_transformer_frac==1`, to trigger a textual inversion run set `--train_transformer_frac==0`. Values between 0 and 1 are supported as well, and we welcome the community to experiment w/ different settings and share the results! +* **token initializer** - similar to the original textual inversion work, you can specify a concept of your choosing as the starting point for training. By default, when enabling `--train_text_encoder_ti`, the new inserted tokens are initialized randomly. 
You can specify a token in `--initializer_concept` such that the starting point for the trained embeddings will be the embeddings associated with your chosen `--initializer_concept`. + +## Training examples + +Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon. + +Let's first download it locally: + +```python +from huggingface_hub import snapshot_download + +local_dir = "./3d_icon" +snapshot_download( + "LinoyTsaban/3d_icon", + local_dir=local_dir, repo_type="dataset", + ignore_patterns=".gitattributes", +) +``` + +Let's review some of the advanced features we're going to be using for this example: +- **custom captions**: +To use custom captioning, first ensure that you have the datasets library installed, otherwise you can install it by +```bash +pip install datasets +``` + +Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt") + +``` +--dataset_name=./3d_icon +--caption_column=prompt +``` + +You can also load a dataset straight from by specifying it's name in `dataset_name`. +Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loadin your own caption dataset. + +- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer +- **pivotal tuning** + +### Example #1: Pivotal tuning +**Now, we can launch training:** + +```bash +export MODEL_NAME="black-forest-labs/FLUX.1-dev" +export DATASET_NAME="./3d_icon" +export OUTPUT_DIR="3d-icon-Flux-LoRA" + +accelerate launch train_dreambooth_lora_flux_advanced.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --instance_prompt="3d icon in the style of TOK" \ + --output_dir=$OUTPUT_DIR \ + --caption_column="prompt" \ + --mixed_precision="bf16" \ + --resolution=1024 \ + --train_batch_size=1 \ + --repeats=1 \ + --report_to="wandb"\ + --gradient_accumulation_steps=1 \ + --gradient_checkpointing \ + --learning_rate=1.0 \ + --text_encoder_lr=1.0 \ + --optimizer="prodigy"\ + --train_text_encoder_ti\ + --train_text_encoder_ti_frac=0.5\ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --rank=8 \ + --max_train_steps=700 \ + --checkpointing_steps=2000 \ + --seed="0" \ + --push_to_hub +``` + +To better track our training experiments, we're using the following flags in the command above: + +* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`. +* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. + +Our experiments were conducted on a single 40GB A100 GPU. + +### Example #2: Pivotal tuning with T5 +Now let's try that with T5 as well, so instead of only optimizing the CLIP embeddings associated with newly inserted tokens, we'll optimize +the T5 embeddings as well. 
We can do this by simply adding `--enable_t5_ti` to the previous configuration: +```bash +export MODEL_NAME="black-forest-labs/FLUX.1-dev" +export DATASET_NAME="./3d_icon" +export OUTPUT_DIR="3d-icon-Flux-LoRA" + +accelerate launch train_dreambooth_lora_flux_advanced.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --instance_prompt="3d icon in the style of TOK" \ + --output_dir=$OUTPUT_DIR \ + --caption_column="prompt" \ + --mixed_precision="bf16" \ + --resolution=1024 \ + --train_batch_size=1 \ + --repeats=1 \ + --report_to="wandb"\ + --gradient_accumulation_steps=1 \ + --gradient_checkpointing \ + --learning_rate=1.0 \ + --text_encoder_lr=1.0 \ + --optimizer="prodigy"\ + --train_text_encoder_ti\ + --enable_t5_ti\ + --train_text_encoder_ti_frac=0.5\ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --rank=8 \ + --max_train_steps=700 \ + --checkpointing_steps=2000 \ + --seed="0" \ + --push_to_hub +``` + +### Example #3: Textual Inversion +To explore a pure textual inversion - i.e. only optimizing the text embeddings w/o training transformer LoRA layers, we +can set the value for `--train_transformer_frac` - which is responsible for the percent of epochs in which the transformer is +trained. By setting `--train_transformer_frac == 0` and enabling `--train_text_encoder_ti` we trigger a textual inversion train +run. +```bash +export MODEL_NAME="black-forest-labs/FLUX.1-dev" +export DATASET_NAME="./3d_icon" +export OUTPUT_DIR="3d-icon-Flux-LoRA" + +accelerate launch train_dreambooth_lora_flux_advanced.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_NAME \ + --instance_prompt="3d icon in the style of TOK" \ + --output_dir=$OUTPUT_DIR \ + --caption_column="prompt" \ + --mixed_precision="bf16" \ + --resolution=1024 \ + --train_batch_size=1 \ + --repeats=1 \ + --report_to="wandb"\ + --gradient_accumulation_steps=1 \ + --gradient_checkpointing \ + --learning_rate=1.0 \ + --text_encoder_lr=1.0 \ + --optimizer="prodigy"\ + --train_text_encoder_ti\ + --enable_t5_ti\ + --train_text_encoder_ti_frac=0.5\ + --train_transformer_frac=0\ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --rank=8 \ + --max_train_steps=700 \ + --checkpointing_steps=2000 \ + --seed="0" \ + --push_to_hub +``` +### Inference - pivotal tuning + +Once training is done, we can perform inference like so: +1. starting with loading the transformer lora weights +```python +import torch +from huggingface_hub import hf_hub_download, upload_file +from diffusers import AutoPipelineForText2Image +from safetensors.torch import load_file + +username = "linoyts" +repo_id = f"{username}/3d-icon-Flux-LoRA" + +pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda') + + +pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors") +``` +2. now we load the pivotal tuning embeddings +> [!NOTE] #1 if `--enable_t5_ti` wasn't passed, we only load the embeddings to the CLIP encoder. + +> [!NOTE] #2 the number of tokens (i.e. ,...,) is either determined by `--num_new_tokens_per_abstraction` or by `--initializer_concept`. 
Make sure to update inference code accordingly :) +```python +text_encoders = [pipe.text_encoder, pipe.text_encoder_2] +tokenizers = [pipe.tokenizer, pipe.tokenizer_2] + +embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model") + +state_dict = load_file(embedding_path) +# load embeddings of text_encoder 1 (CLIP ViT-L/14) +pipe.load_textual_inversion(state_dict["clip_l"], token=["", ""], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer) +# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti` +pipe.load_textual_inversion(state_dict["t5"], token=["", ""], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2) +``` + +3. let's generate images + +```python +instance_token = "" +prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}" + +image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0] +image.save("llama.png") +``` + +### Inference - pure textual inversion +In this case, we don't load transformer layers as before, since we only optimize the text embeddings. The output of a textual inversion train run is a +`.safetensors` file containing the trained embeddings for the new tokens either for the CLIP encoder, or for both encoders (CLIP and T5) + +1. starting with loading the embeddings. +💡note that here too, if you didn't enable `--enable_t5_ti`, you only load the embeddings to the CLIP encoder + +```python +import torch +from huggingface_hub import hf_hub_download, upload_file +from diffusers import AutoPipelineForText2Image +from safetensors.torch import load_file + +username = "linoyts" +repo_id = f"{username}/3d-icon-Flux-LoRA" + +pipe = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda') + +text_encoders = [pipe.text_encoder, pipe.text_encoder_2] +tokenizers = [pipe.tokenizer, pipe.tokenizer_2] + +embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-Flux-LoRA_emb.safetensors", repo_type="model") + +state_dict = load_file(embedding_path) +# load embeddings of text_encoder 1 (CLIP ViT-L/14) +pipe.load_textual_inversion(state_dict["clip_l"], token=["", ""], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer) +# load embeddings of text_encoder 2 (T5 XXL) - ignore this line if you didn't enable `--enable_t5_ti` +pipe.load_textual_inversion(state_dict["t5"], token=["", ""], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2) +``` +2. let's generate images + +```python +instance_token = "" +prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}" + +image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0] +image.save("llama.png") +``` + +### Comfy UI / AUTOMATIC1111 Inference +The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats! + +**AUTOMATIC1111 / SD.Next** \ +In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time. +- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. +- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory. 
+ +You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls `. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`. + +**ComfyUI** \ +In ComfyUI we will load a LoRA and a textual embedding at the same time. +- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/) +- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/). diff --git a/examples/advanced_diffusion_training/requirements_flux.txt b/examples/advanced_diffusion_training/requirements_flux.txt new file mode 100644 index 000000000000..dbc124ff6526 --- /dev/null +++ b/examples/advanced_diffusion_training/requirements_flux.txt @@ -0,0 +1,8 @@ +accelerate>=0.31.0 +torchvision +transformers>=4.41.2 +ftfy +tensorboard +Jinja2 +peft>=0.11.1 +sentencepiece \ No newline at end of file diff --git a/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py new file mode 100644 index 000000000000..e29c99821303 --- /dev/null +++ b/examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py @@ -0,0 +1,283 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import os +import sys +import tempfile + +import safetensors + + +sys.path.append("..") +from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402 + + +logging.basicConfig(level=logging.DEBUG) + +logger = logging.getLogger() +stream_handler = logging.StreamHandler(sys.stdout) +logger.addHandler(stream_handler) + + +class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate): + instance_data_dir = "docs/source/en/imgs" + instance_prompt = "photo" + pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe" + script_path = "examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py" + + def test_dreambooth_lora_flux(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_text_encoder_flux(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --train_text_encoder + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. 
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + starts_with_expected_prefix = all( + (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys() + ) + self.assertTrue(starts_with_expected_prefix) + + def test_dreambooth_lora_pivotal_tuning_flux_clip(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --train_text_encoder_ti + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + # make sure embeddings were also saved + self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # make sure the state_dict has the correct naming in the parameters. + textual_inversion_state_dict = safetensors.torch.load_file( + os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors") + ) + is_clip = all("clip_l" in k for k in textual_inversion_state_dict.keys()) + self.assertTrue(is_clip) + + # when performing pivotal tuning, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --train_text_encoder_ti + --enable_t5_ti + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + # make sure embeddings were also saved + self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # make sure the state_dict has the correct naming in the parameters. 
+ textual_inversion_state_dict = safetensors.torch.load_file( + os.path.join(tmpdir, f"{os.path.basename(tmpdir)}_emb.safetensors") + ) + is_te = all(("clip_l" in k or "t5" in k) for k in textual_inversion_state_dict.keys()) + self.assertTrue(is_te) + + # when performing pivotal tuning, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_latent_caching(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=6 + --checkpoints_total_limit=2 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual( + {x for x in os.listdir(tmpdir) if "checkpoint" in x}, + {"checkpoint-4", "checkpoint-6"}, + ) + + def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=4 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"}) + + resume_run_args = f""" + {self.script_path} + --pretrained_model_name_or_path={self.pretrained_model_name_or_path} + --instance_data_dir={self.instance_data_dir} + --output_dir={tmpdir} + --instance_prompt={self.instance_prompt} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --max_train_steps=8 + --checkpointing_steps=2 + --resume_from_checkpoint=checkpoint-4 + --checkpoints_total_limit=2 + 
""".split() + + run_command(self._launch_args + resume_run_args) + + self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"}) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py new file mode 100644 index 000000000000..e3e46ead8ee3 --- /dev/null +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -0,0 +1,2463 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import itertools +import logging +import math +import os +import random +import re +import shutil +from contextlib import nullcontext +from pathlib import Path +from typing import List, Optional, Union + +import numpy as np +import torch +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from safetensors.torch import save_file +from torch.utils.data import Dataset +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast + +import diffusers +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + FluxPipeline, + FluxTransformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + _set_state_dict_into_text_encoder, + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + free_memory, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.32.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model: str = None, + train_text_encoder=False, + train_text_encoder_ti=False, + enable_t5_ti=False, + pure_textual_inversion=False, + token_abstraction_dict=None, + instance_prompt=None, + validation_prompt=None, + repo_folder=None, +): + widget_dict = [] + trigger_str = f"You should use {instance_prompt} to trigger the image generation." 
+ + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + widget_dict.append( + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} + ) + diffusers_load_lora = "" + diffusers_imports_pivotal = "" + diffusers_example_pivotal = "" + if not pure_textual_inversion: + diffusers_load_lora = ( + f"""pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')""" + ) + if train_text_encoder_ti: + embeddings_filename = f"{repo_folder}_emb" + ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"", instance_prompt)) + trigger_str = ( + "To trigger image generation of trained concept(or concepts) replace each concept identifier " + "in you prompt with the new inserted tokens:\n" + ) + diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download + from safetensors.torch import load_file + """ + if enable_t5_ti: + diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model") + state_dict = load_file(embedding_path) + pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) + pipeline.load_textual_inversion(state_dict["t5"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2) + """ + else: + diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model") + state_dict = load_file(embedding_path) + pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) + """ + if token_abstraction_dict: + for key, value in token_abstraction_dict.items(): + tokens = "".join(value) + trigger_str += f""" + to trigger concept `{key}` → use `{tokens}` in your prompt \n + """ + + model_description = f""" +# Flux DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md). + +Was LoRA for the text encoder enabled? {train_text_encoder}. + +Pivotal tuning was enabled: {train_text_encoder_ti}. + +## Trigger words + +{trigger_str} + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) + +```py +from diffusers import AutoPipelineForText2Image +import torch +{diffusers_imports_pivotal} +pipeline = AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to('cuda') +{diffusers_load_lora} +{diffusers_example_pivotal} +image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] +``` + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md). 
+""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "flux", + "flux-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def load_text_encoders(class_one, class_two): + text_encoder_one = class_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_two = class_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + return text_encoder_one, text_encoder_two + + +def log_validation( + pipeline, + args, + accelerator, + pipeline_args, + epoch, + torch_dtype, + is_final_validation=False, +): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + autocast_ctx = nullcontext() + + with autocast_ctx: + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + + for tracker in accelerator.trackers: + phase_name = "test" if is_final_validation else "validation" + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + phase_name: [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + free_memory() + + return images + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--token_abstraction", + type=str, + default="TOK", + help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, " + "captions - e.g. TOK. To use multiple identifiers, please specify them in a comma separated string - e.g. " + "'TOK,TOK2,TOK3' etc.", + ) + + parser.add_argument( + "--num_new_tokens_per_abstraction", + type=int, + default=None, + help="number of new tokens inserted to the tokenizers per token_abstraction identifier when " + "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new " + "tokens - ", + ) + parser.add_argument( + "--initializer_concept", + type=str, + default=None, + help="the concept to use to initialize the new inserted tokens when training with " + "--train_text_encoder_ti = True. By default, new tokens () are initialized with random value. " + "Alternatively, you could specify a different word/words whos value will be used as the starting point for the new inserted tokens. 
" + "--num_new_tokens_per_abstraction is ignored when initializer_concept is provided", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=512, + help="Maximum sequence length to use with with the T5 text encoder", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. 
If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_text_encoder_ti", + action="store_true", + help=("Whether to use pivotal tuning / textual inversion"), + ) + parser.add_argument( + "--enable_t5_ti", + action="store_true", + help=( + "Whether to use pivotal tuning / textual inversion for the T5 encoder as well (in addition to CLIP encoder)" + ), + ) + + parser.add_argument( + "--train_text_encoder_ti_frac", + type=float, + default=0.5, + help=("The percentage of epochs to perform textual inversion"), + ) + + parser.add_argument( + "--train_text_encoder_frac", + type=float, + default=1.0, + help=("The percentage of epochs to perform text encoder tuning"), + ) + parser.add_argument( + "--train_transformer_frac", + type=float, + default=1.0, + help=("The percentage of epochs to perform transformer tuning"), + ) + + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
+ ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument( + "--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params" + ) + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + "The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. " + 'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md' + ), + ) + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. 
" + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + if args.train_text_encoder and args.train_text_encoder_ti: + raise ValueError( + "Specify only one of `--train_text_encoder` or `--train_text_encoder_ti. " + "For full LoRA text encoder training check --train_text_encoder, for textual " + "inversion training check `--train_text_encoder_ti`" + ) + if args.train_transformer_frac < 1 and not args.train_text_encoder_ti: + raise ValueError( + "--train_transformer_frac must be == 1 if text_encoder training / textual inversion is not enabled." + ) + if args.train_transformer_frac < 1 and args.train_text_encoder_ti_frac < 1: + raise ValueError( + "--train_transformer_frac and --train_text_encoder_ti_frac are identical and smaller than 1. " + "This contradicts with --max_train_steps, please specify different values or set both to 1." 
+
+        )
+    if args.enable_t5_ti and not args.train_text_encoder_ti:
+        logger.warning("You need not use --enable_t5_ti without --train_text_encoder_ti.")
+
+    if args.train_text_encoder_ti and args.initializer_concept and args.num_new_tokens_per_abstraction:
+        logger.warning(
+            "When specifying --initializer_concept, the number of tokens per abstraction is determined "
+            "by the initializer token. --num_new_tokens_per_abstraction will be ignored"
+        )
+
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    if args.with_prior_preservation:
+        if args.class_data_dir is None:
+            raise ValueError("You must specify a data directory for class images.")
+        if args.class_prompt is None:
+            raise ValueError("You must specify a prompt for class images.")
+    else:
+        if args.class_data_dir is not None:
+            logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
+        if args.class_prompt is not None:
+            logger.warning("You need not use --class_prompt without --with_prior_preservation.")
+
+    return args
+
+
+# Modified from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py
+class TokenEmbeddingsHandler:
+    def __init__(self, text_encoders, tokenizers):
+        self.text_encoders = text_encoders
+        self.tokenizers = tokenizers
+
+        self.train_ids: Optional[torch.Tensor] = None
+        self.train_ids_t5: Optional[torch.Tensor] = None
+        self.inserting_toks: Optional[List[str]] = None
+        self.embeddings_settings = {}
+
+    def initialize_new_tokens(self, inserting_toks: List[str]):
+        idx = 0
+        for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders):
+            assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings."
+            assert all(
+                isinstance(tok, str) for tok in inserting_toks
+            ), "All elements in inserting_toks should be strings."
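+            # Illustrative example with the defaults: for `--token_abstraction "TOK"` and two new tokens per
+            # abstraction, main() builds token_abstraction_dict = {"TOK": ["<s0>", "<s1>"]} and calls this
+            # method with inserting_toks=["<s0>", "<s1>"]; the lines below register those strings as extra
+            # special tokens and resize the corresponding text encoder's embedding matrix to make room.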
+ + self.inserting_toks = inserting_toks + special_tokens_dict = {"additional_special_tokens": self.inserting_toks} + tokenizer.add_special_tokens(special_tokens_dict) + # Resize the token embeddings as we are adding new special tokens to the tokenizer + text_encoder.resize_token_embeddings(len(tokenizer)) + + # Convert the token abstractions to ids + if idx == 0: + self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) + else: + self.train_ids_t5 = tokenizer.convert_tokens_to_ids(self.inserting_toks) + + # random initialization of new tokens + embeds = ( + text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens + ) + std_token_embedding = embeds.weight.data.std() + + logger.info(f"{idx} text encoder's std_token_embedding: {std_token_embedding}") + + train_ids = self.train_ids if idx == 0 else self.train_ids_t5 + # if initializer_concept are not provided, token embeddings are initialized randomly + if args.initializer_concept is None: + hidden_size = ( + text_encoder.text_model.config.hidden_size if idx == 0 else text_encoder.encoder.config.hidden_size + ) + embeds.weight.data[train_ids] = ( + torch.randn(len(train_ids), hidden_size).to(device=self.device).to(dtype=self.dtype) + * std_token_embedding + ) + else: + # Convert the initializer_token, placeholder_token to ids + initializer_token_ids = tokenizer.encode(args.initializer_concept, add_special_tokens=False) + for token_idx, token_id in enumerate(train_ids): + embeds.weight.data[token_id] = (embeds.weight.data)[ + initializer_token_ids[token_idx % len(initializer_token_ids)] + ].clone() + + self.embeddings_settings[f"original_embeddings_{idx}"] = embeds.weight.data.clone() + self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding + + # makes sure we don't update any embedding weights besides the newly added token + index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) + index_no_updates[train_ids] = False + + self.embeddings_settings[f"index_no_updates_{idx}"] = index_no_updates + + logger.info(self.embeddings_settings[f"index_no_updates_{idx}"].shape) + + idx += 1 + + def save_embeddings(self, file_path: str): + assert self.train_ids is not None, "Initialize new tokens before saving embeddings." + tensors = {} + # text_encoder_one, idx==0 - CLIP ViT-L/14, text_encoder_two, idx==1 - T5 xxl + idx_to_text_encoder_name = {0: "clip_l", 1: "t5"} + for idx, text_encoder in enumerate(self.text_encoders): + train_ids = self.train_ids if idx == 0 else self.train_ids_t5 + embeds = ( + text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens + ) + assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same." 
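+            # The tensors saved below map "clip_l" (and "t5" when T5 pivotal tuning is enabled) to the newly
+            # trained rows of the embedding matrix, shape (num_new_tokens, hidden_size) - e.g. roughly (2, 768)
+            # for CLIP-L with the default of two inserted tokens. The model card's inference snippet reloads
+            # them with safetensors' load_file() and pipeline.load_textual_inversion(state_dict["clip_l"], ...).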
+ new_token_embeddings = embeds.weight.data[train_ids] + + # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), + # Note: When loading with diffusers, any name can work - simply specify in inference + tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings + # tensors[f"text_encoders_{idx}"] = new_token_embeddings + + save_file(tensors, file_path) + + @property + def dtype(self): + return self.text_encoders[0].dtype + + @property + def device(self): + return self.text_encoders[0].device + + @torch.no_grad() + def retract_embeddings(self): + for idx, text_encoder in enumerate(self.text_encoders): + embeds = ( + text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens + ) + index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"] + embeds.weight.data[index_no_updates] = ( + self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates] + .to(device=text_encoder.device) + .to(dtype=text_encoder.dtype) + ) + + # for the parts that were updated, we need to normalize them + # to have the same std as before + std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"] + + index_updates = ~index_no_updates + new_embeddings = embeds.weight.data[index_updates] + off_ratio = std_token_embedding / new_embeddings.std() + + new_embeddings = new_embeddings * (off_ratio**0.1) + embeds.weight.data[index_updates] = new_embeddings + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + train_text_encoder_ti, + token_abstraction_dict=None, # token mapping for textual inversion + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + self.token_abstraction_dict = token_abstraction_dict + self.train_text_encoder_ti = train_text_encoder_ti + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. 
Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + + self.pixel_values = [] + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + for image in self.instance_images: + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + image = train_transforms(image) + self.pixel_values.append(image) + + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.pixel_values[index % self.num_instance_images] + example["instance_images"] = instance_image + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + if self.train_text_encoder_ti: + # replace instances of 
--token_abstraction in caption with the new tokens: "" etc. + for token_abs, token_replacement in self.token_abstraction_dict.items(): + caption = caption.replace(token_abs, "".join(token_replacement)) + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # the given instance prompt is used for all images + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=False): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + add_special_tokens=add_special_tokens, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + +def _get_t5_prompt_embeds( + text_encoder, + tokenizer, + max_sequence_length=512, + prompt=None, + num_images_per_prompt=1, + device=None, + text_input_ids=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + + dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + +def _get_clip_prompt_embeds( + text_encoder, + tokenizer, + prompt: str, + device=None, + text_input_ids=None, + num_images_per_prompt: int = 1, +): + prompt = 
[prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + else: + if text_input_ids is None: + raise ValueError("text_input_ids must be provided when the tokenizer is not specified") + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + +def encode_prompt( + text_encoders, + tokenizers, + prompt: str, + max_sequence_length, + device=None, + num_images_per_prompt: int = 1, + text_input_ids_list=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + dtype = text_encoders[0].dtype + + pooled_prompt_embeds = _get_clip_prompt_embeds( + text_encoder=text_encoders[0], + tokenizer=tokenizers[0], + prompt=prompt, + device=device if device is not None else text_encoders[0].device, + num_images_per_prompt=num_images_per_prompt, + text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None, + ) + + prompt_embeds = _get_t5_prompt_embeds( + text_encoder=text_encoders[1], + tokenizer=tokenizers[1], + max_sequence_length=max_sequence_length, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + device=device if device is not None else text_encoders[1].device, + text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None, + ) + + text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + +# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer: +# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95 +class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + with torch.no_grad(): + # create weights for timesteps + num_timesteps = 1000 + + # generate the multiplier based on cosmap loss weighing + # this is only used on linear timesteps for now + + # cosine map weighing is higher in the middle and lower at the ends + # bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2 + # cosmap_weighing = 2 / (math.pi * bot) + + # sigma sqrt weighing is significantly higher at the end and lower at the beginning + sigma_sqrt_weighing = (self.sigmas**-2.0).float() + # clip at 1e4 (1e6 is too high) + sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4) + # bring to a mean of 1 + sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean() + + # Create linear timesteps from 1000 to 0 + timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu") + + self.linear_timesteps = timesteps + # self.linear_timesteps_weights = cosmap_weighing + self.linear_timesteps_weights = sigma_sqrt_weighing + + # self.sigmas = self.get_sigmas(timesteps, 
n_dim=1, dtype=torch.float32, device='cpu') + pass + + def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor: + # Get the indices of the timesteps + step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps] + + # Get the weights for the timesteps + weights = self.linear_timesteps_weights[step_indices].flatten() + + return weights + + def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor: + sigmas = self.sigmas.to(device=device, dtype=dtype) + schedule_timesteps = self.timesteps.to(device) + timesteps = timesteps.to(device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + + return sigma + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + ## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578 + ## Add noise according to flow matching. + ## zt = (1 - texp) * x + texp * z1 + + # sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + # noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # timestep needs to be in [0, 1], we store them in [0, 1000] + # noisy_sample = (1 - timestep) * latent + timestep * noise + t_01 = (timesteps / 1000).to(original_samples.device) + noisy_model_input = (1 - t_01) * original_samples + t_01 * noise + + # n_dim = original_samples.ndim + # sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device) + # noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise + return noisy_model_input + + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + return sample + + def set_train_timesteps(self, num_timesteps, device, linear=False): + if linear: + timesteps = torch.linspace(1000, 0, num_timesteps, device=device) + self.timesteps = timesteps + return timesteps + else: + # distribute them closer to center. Inference distributes them as a bias toward first + # Generate values from 0 to 1 + t = torch.sigmoid(torch.randn((num_timesteps,), device=device)) + + # Scale and reverse the values to go from 1000 to 0 + timesteps = (1 - t) * 1000 + + # Sort the timesteps in descending order + timesteps, _ = torch.sort(timesteps, descending=True) + + self.timesteps = timesteps.to(device=device) + + return timesteps + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." 
+ ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available() + torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + revision=args.revision, + variant=args.variant, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + free_memory() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + model_id = args.hub_model_id or Path(args.output_dir).name + repo_id = None + if args.push_to_hub: + repo_id = create_repo( + repo_id=model_id, + exist_ok=True, + ).repo_id + + # Load the tokenizers + tokenizer_one = 
CLIPTokenizer.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer",
+        revision=args.revision,
+    )
+    tokenizer_two = T5TokenizerFast.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="tokenizer_2",
+        revision=args.revision,
+    )
+
+    # import correct text encoder classes
+    text_encoder_cls_one = import_model_class_from_model_name_or_path(
+        args.pretrained_model_name_or_path, args.revision
+    )
+    text_encoder_cls_two = import_model_class_from_model_name_or_path(
+        args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+    )
+
+    # Load scheduler and models
+    noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="scheduler"
+    )
+    noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+    vae = AutoencoderKL.from_pretrained(
+        args.pretrained_model_name_or_path,
+        subfolder="vae",
+        revision=args.revision,
+        variant=args.variant,
+    )
+    transformer = FluxTransformer2DModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+    )
+
+    if args.train_text_encoder_ti:
+        # we parse the provided token identifier (or identifiers) into a list. s.t. - "TOK" -> ["TOK"], "TOK,
+        # TOK2" -> ["TOK", "TOK2"] etc.
+        token_abstraction_list = [place_holder.strip() for place_holder in re.split(r",\s*", args.token_abstraction)]
+        logger.info(f"list of token identifiers: {token_abstraction_list}")
+
+        if args.initializer_concept is None:
+            num_new_tokens_per_abstraction = (
+                2 if args.num_new_tokens_per_abstraction is None else args.num_new_tokens_per_abstraction
+            )
+        # if args.initializer_concept is provided, we ignore args.num_new_tokens_per_abstraction
+        else:
+            token_ids = tokenizer_one.encode(args.initializer_concept, add_special_tokens=False)
+            num_new_tokens_per_abstraction = len(token_ids)
+            if args.enable_t5_ti:
+                token_ids_t5 = tokenizer_two.encode(args.initializer_concept, add_special_tokens=False)
+                num_new_tokens_per_abstraction = max(len(token_ids), len(token_ids_t5))
+            logger.info(
+                f"initializer_concept: {args.initializer_concept}, num_new_tokens_per_abstraction: {num_new_tokens_per_abstraction}"
+            )
+
+        token_abstraction_dict = {}
+        token_idx = 0
+        for i, token in enumerate(token_abstraction_list):
+            token_abstraction_dict[token] = [f"<s{token_idx + i + j}>" for j in range(num_new_tokens_per_abstraction)]
+            token_idx += num_new_tokens_per_abstraction - 1
+
+        # replace instances of --token_abstraction in --instance_prompt with the new tokens: "<si><si+1>" etc.
+        for token_abs, token_replacement in token_abstraction_dict.items():
+            new_instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement))
+            if args.instance_prompt == new_instance_prompt:
+                logger.warning(
+                    "Note! the instance prompt provided in --instance_prompt does not include the token abstraction specified by "
+                    "--token_abstraction. 
This may lead to incorrect optimization of text embeddings during pivotal tuning" + ) + args.instance_prompt = new_instance_prompt + if args.with_prior_preservation: + args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement)) + if args.validation_prompt: + args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement)) + + # initialize the new tokens for textual inversion + text_encoders = [text_encoder_one, text_encoder_two] if args.enable_t5_ti else [text_encoder_one] + tokenizers = [tokenizer_one, tokenizer_two] if args.enable_t5_ti else [tokenizer_one] + embedding_handler = TokenEmbeddingsHandler(text_encoders, tokenizers) + inserting_toks = [] + for new_tok in token_abstraction_dict.values(): + inserting_toks.extend(new_tok) + embedding_handler.initialize_new_tokens(inserting_toks=inserting_toks) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + vae.to(accelerator.device, dtype=weight_dtype) + transformer.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder_one.gradient_checkpointing_enable() + + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + target_modules = [ + "attn.to_k", + "attn.to_q", + "attn.to_v", + "attn.to_out.0", + "attn.add_k_proj", + "attn.add_q_proj", + "attn.add_v_proj", + "attn.to_add_out", + "ff.net.0.proj", + "ff.net.2", + "ff_context.net.0.proj", + "ff_context.net.2", + ] + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=target_modules, + ) + transformer.add_adapter(transformer_lora_config) + + if args.train_text_encoder: + text_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + ) + text_encoder_one.add_adapter(text_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + text_encoder_one_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(unwrap_model(transformer))): 
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model) + elif isinstance(model, type(unwrap_model(text_encoder_one))): + if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers + text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + FluxPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, + ) + if args.train_text_encoder_ti: + embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors") + + def load_model_hook(models, input_dir): + transformer_ = None + text_encoder_one_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + elif isinstance(model, type(unwrap_model(text_encoder_one))): + text_encoder_one_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict = FluxPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + if args.train_text_encoder: + # Do we need to call `scale_lora_layers()` here? + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + if args.train_text_encoder: + models.extend([text_encoder_one_]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. 
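+    # Optimizing fp16 copies of the LoRA/embedding parameters directly is numerically unstable, so
+    # `cast_training_params` upcasts only the parameters that require gradients to fp32 while leaving the
+    # frozen half-precision base weights untouched, keeping the extra memory cost small.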
+ if args.mixed_precision == "fp16": + models = [transformer] + if args.train_text_encoder: + models.extend([text_encoder_one]) + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + if args.train_text_encoder: + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + # if we use textual inversion, we freeze all parameters except for the token embeddings + # in text encoder + elif args.train_text_encoder_ti: + text_lora_parameters_one = [] # CLIP + for name, param in text_encoder_one.named_parameters(): + if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param.data = param.to(dtype=torch.float32) + param.requires_grad = True + text_lora_parameters_one.append(param) + else: + param.requires_grad = False + if args.enable_t5_ti: # whether to do pivotal tuning/textual inversion for T5 as well + text_lora_parameters_two = [] + for name, param in text_encoder_two.named_parameters(): + if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param.data = param.to(dtype=torch.float32) + param.requires_grad = True + text_lora_parameters_two.append(param) + else: + param.requires_grad = False + + # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training + freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti) + + # if --train_text_encoder_ti and train_transformer_frac == 0 where essentially performing textual inversion + # and not training transformer LoRA layers + pure_textual_inversion = args.train_text_encoder_ti and args.train_transformer_frac == 0 + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + if not freeze_text_encoder: + # different learning rate for text encoder and transformer + text_parameters_one_with_lr = { + "params": text_lora_parameters_one, + "weight_decay": args.adam_weight_decay_text_encoder + if args.adam_weight_decay_text_encoder + else args.adam_weight_decay, + "lr": args.text_encoder_lr, + } + if not args.enable_t5_ti: + # pure textual inversion - only clip + if pure_textual_inversion: + params_to_optimize = [ + text_parameters_one_with_lr, + ] + te_idx = 0 + else: # regular te training or regular pivotal for clip + params_to_optimize = [ + transformer_parameters_with_lr, + text_parameters_one_with_lr, + ] + te_idx = 1 + elif args.enable_t5_ti: + # pivotal tuning of clip & t5 + text_parameters_two_with_lr = { + "params": text_lora_parameters_two, + "weight_decay": args.adam_weight_decay_text_encoder + if args.adam_weight_decay_text_encoder + else args.adam_weight_decay, + "lr": args.text_encoder_lr, + } + # pure textual inversion - only clip & t5 + if pure_textual_inversion: + params_to_optimize = [text_parameters_one_with_lr, text_parameters_two_with_lr] + te_idx = 0 + else: # regular pivotal tuning of clip & t5 + params_to_optimize = [ + transformer_parameters_with_lr, + text_parameters_one_with_lr, + text_parameters_two_with_lr, + ] + te_idx = 1 + else: + params_to_optimize = [ + transformer_parameters_with_lr, + ] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of 
optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + if not freeze_text_encoder and args.text_encoder_lr: + logger.warning( + f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters to be + # --learning_rate + + params_to_optimize[te_idx]["lr"] = args.learning_rate + params_to_optimize[-1]["lr"] = args.learning_rate + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + train_text_encoder_ti=args.train_text_encoder_ti, + token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + if freeze_text_encoder: + tokenizers = [tokenizer_one, tokenizer_two] + text_encoders = [text_encoder_one, text_encoder_two] + + def compute_text_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders, tokenizers, prompt, args.max_sequence_length + ) + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + text_ids = text_ids.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds, text_ids + + # If no type of tuning is done on the text_encoder and custom instance 
prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if freeze_text_encoder and not train_dataset.custom_instance_prompts: + instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + if freeze_text_encoder: + class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers + ) + + # Clear the memory here + if freeze_text_encoder and not train_dataset.custom_instance_prompts: + del tokenizers, text_encoders, text_encoder_one, text_encoder_two + free_memory() + + # if --train_text_encoder_ti we need add_special_tokens to be True for textual inversion + add_special_tokens_clip = True if args.train_text_encoder_ti else False + add_special_tokens_t5 = True if (args.train_text_encoder_ti and args.enable_t5_ti) else False + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + + if not train_dataset.custom_instance_prompts: + if freeze_text_encoder: + prompt_embeds = instance_prompt_hidden_states + pooled_prompt_embeds = instance_pooled_prompt_embeds + text_ids = instance_text_ids + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) + text_ids = torch.cat([text_ids, class_text_ids], dim=0) + # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) + # we need to tokenize and encode the batch prompts on all training steps + else: + tokens_one = tokenize_prompt( + tokenizer_one, args.instance_prompt, max_sequence_length=77, add_special_tokens=add_special_tokens_clip + ) + tokens_two = tokenize_prompt( + tokenizer_two, + args.instance_prompt, + max_sequence_length=args.max_sequence_length, + add_special_tokens=add_special_tokens_t5, + ) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt( + tokenizer_one, + args.class_prompt, + max_sequence_length=77, + add_special_tokens=add_special_tokens_clip, + ) + class_tokens_two = tokenize_prompt( + tokenizer_two, + args.class_prompt, + max_sequence_length=args.max_sequence_length, + add_special_tokens=add_special_tokens_t5, + ) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + + vae_config_shift_factor = vae.config.shift_factor + vae_config_scaling_factor = vae.config.scaling_factor + vae_config_block_out_channels = vae.config.block_out_channels + if args.cache_latents: + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=weight_dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + + if args.validation_prompt is None: + del vae + free_memory() + + # Scheduler and math around the number of training steps. 
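The hunk above encodes the single instance (and class) prompt once when the text encoders are frozen, and optionally caches the VAE latent distributions up front so the VAE and text encoders can be released before training; the scheduler and accelerator setup continue right after this note. A minimal, self-contained sketch of that precompute-once pattern, using hypothetical stand-ins rather than the script's actual encoders and dataloader:

```py
# Minimal sketch of the precompute-once pattern used above.
# `encoder` and `batches` are hypothetical stand-ins, not the script's text encoders, VAE, or dataloader.
import torch

encoder = torch.nn.Linear(8, 4).eval()            # stands in for a frozen encoder
batches = [torch.randn(2, 8) for _ in range(3)]   # stands in for a dataloader

cache = []
with torch.no_grad():                             # frozen module: no gradients needed
    for batch in batches:
        cache.append(encoder(batch).cpu())        # encode once, keep the result around

del encoder                                       # the frozen module can now be released
# During training, reuse cache[step] instead of re-encoding on every step.
```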
+ overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if not freeze_text_encoder: + if args.enable_t5_ti: + ( + transformer, + text_encoder_one, + text_encoder_two, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + transformer, + text_encoder_one, + text_encoder_two, + optimizer, + train_dataloader, + lr_scheduler, + ) + else: + transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler + ) + + else: + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initialize automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-flux-dev-lora-advanced" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + logger.info(f"Checkpoint '{args.resume_from_checkpoint}' does not exist.
Starting a new training run.") + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + logger.info(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + if args.train_text_encoder: + num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs) + num_train_epochs_transformer = int(args.train_transformer_frac * args.num_train_epochs) + elif args.train_text_encoder_ti: # args.train_text_encoder_ti + num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs) + num_train_epochs_transformer = int(args.train_transformer_frac * args.num_train_epochs) + + # flag used for textual inversion + pivoted_te = False + pivoted_tr = False + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + # if performing any kind of optimization of text_encoder params + if args.train_text_encoder or args.train_text_encoder_ti: + if epoch == num_train_epochs_text_encoder: + # flag to stop text encoder optimization + logger.info(f"PIVOT TE {epoch}") + pivoted_te = True + else: + # still optimizing the text encoder + if args.train_text_encoder: + text_encoder_one.train() + # set top parameter requires_grad = True for gradient checkpointing works + accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + elif args.train_text_encoder_ti: # textual inversion / pivotal tuning + text_encoder_one.train() + if args.enable_t5_ti: + text_encoder_two.train() + + if epoch == num_train_epochs_transformer: + # flag to stop transformer optimization + logger.info(f"PIVOT TRANSFORMER {epoch}") + pivoted_tr = True + + for step, batch in enumerate(train_dataloader): + if pivoted_te: + # stopping optimization of text_encoder params + optimizer.param_groups[te_idx]["lr"] = 0.0 + optimizer.param_groups[-1]["lr"] = 0.0 + elif pivoted_tr and not pure_textual_inversion: + logger.info(f"PIVOT TRANSFORMER {epoch}") + optimizer.param_groups[0]["lr"] = 0.0 + + with accelerator.accumulate(transformer): + prompts = batch["prompts"] + + # encode batch prompts when custom prompts are provided for each image - + if train_dataset.custom_instance_prompts: + if freeze_text_encoder: + prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( + prompts, text_encoders, tokenizers + ) + else: + tokens_one = tokenize_prompt( + tokenizer_one, prompts, max_sequence_length=77, add_special_tokens=add_special_tokens_clip + ) + tokens_two = tokenize_prompt( + tokenizer_two, + prompts, + max_sequence_length=args.max_sequence_length, + add_special_tokens=add_special_tokens_t5, + ) + + if not freeze_text_encoder: + prompt_embeds, 
pooled_prompt_embeds, text_ids = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=[None, None], + text_input_ids_list=[tokens_one, tokens_two], + max_sequence_length=args.max_sequence_length, + device=accelerator.device, + prompt=prompts, + ) + + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].sample() + else: + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor + model_input = model_input.to(dtype=weight_dtype) + + vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) + + latent_image_ids = FluxPipeline._prepare_latent_image_ids( + model_input.shape[0], + model_input.shape[2], + model_input.shape[3], + accelerator.device, + weight_dtype, + ) + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + packed_noisy_model_input = FluxPipeline._pack_latents( + noisy_model_input, + batch_size=model_input.shape[0], + num_channels_latents=model_input.shape[1], + height=model_input.shape[2], + width=model_input.shape[3], + ) + + # handle guidance + if transformer.config.guidance_embeds: + guidance = torch.tensor([args.guidance_scale], device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) + else: + guidance = None + + # Predict the noise residual + model_pred = transformer( + hidden_states=packed_noisy_model_input, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timesteps / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + model_pred = FluxPipeline._unpack_latents( + model_pred, + height=int(model_input.shape[2] * vae_scale_factor / 2), + width=int(model_input.shape[3] * vae_scale_factor / 2), + vae_scale_factor=vae_scale_factor, + ) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
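Before the prior-preservation chunking that follows, it may help to see the flow-matching arithmetic from the hunk above in isolation: timesteps are drawn from a density over sigmas, the noisy input is a linear interpolation between data and noise, and the regression target is `noise - model_input`. A toy sketch with made-up tensors (not the script's latents, scheduler, or transformer):

```py
# Toy sketch of the flow-matching step above; all tensors are made-up stand-ins.
import torch

x = torch.randn(2, 4, 8, 8)                    # "clean" latents
z = torch.randn_like(x)                        # Gaussian noise
sigmas = torch.rand(2).view(-1, 1, 1, 1)       # one sigma per sample, broadcastable to x

x_t = (1.0 - sigmas) * x + sigmas * z          # interpolate between data and noise
target = z - x                                 # velocity the transformer is trained to predict

model_pred = torch.randn_like(x)               # placeholder for the transformer output
# The real hunk also multiplies by a per-sample weighting before averaging.
loss = ((model_pred - target) ** 2).reshape(x.shape[0], -1).mean(dim=1).mean()
```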
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = torch.mean( + (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape( + target_prior.shape[0], -1 + ), + 1, + ) + prior_loss = prior_loss.mean() + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + if not freeze_text_encoder: + if args.train_text_encoder: + params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters()) + elif pure_textual_inversion: + params_to_clip = itertools.chain( + text_encoder_one.parameters(), text_encoder_two.parameters() + ) + else: + params_to_clip = itertools.chain( + transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters() + ) + else: + params_to_clip = itertools.chain(transformer.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # every step, we reset the embeddings to the original embeddings. + if args.train_text_encoder_ti: + embedding_handler.retract_embeddings() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + # create pipeline + if freeze_text_encoder: + text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=accelerator.unwrap_model(text_encoder_one), + text_encoder_2=accelerator.unwrap_model(text_encoder_two), + 
transformer=accelerator.unwrap_model(transformer), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline_args = {"prompt": args.validation_prompt} + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + torch_dtype=weight_dtype, + ) + if freeze_text_encoder: + del text_encoder_one, text_encoder_two + free_memory() + elif args.train_text_encoder: + del text_encoder_two + free_memory() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer = unwrap_model(transformer) + transformer = transformer.to(weight_dtype) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + if args.train_text_encoder: + text_encoder_one = unwrap_model(text_encoder_one) + text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32)) + else: + text_encoder_lora_layers = None + + if not pure_textual_inversion: + FluxPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + ) + + if args.train_text_encoder_ti: + embeddings_path = f"{args.output_dir}/{os.path.basename(args.output_dir)}_emb.safetensors" + embedding_handler.save_embeddings(embeddings_path) + + # Final inference + # Load previous pipeline + pipeline = FluxPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + if not pure_textual_inversion: + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + pipeline_args = {"prompt": args.validation_prompt} + images = log_validation( + pipeline=pipeline, + args=args, + accelerator=accelerator, + pipeline_args=pipeline_args, + epoch=epoch, + torch_dtype=weight_dtype, + is_final_validation=True, + ) + + save_model_card( + model_id if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + train_text_encoder_ti=args.train_text_encoder_ti, + enable_t5_ti=args.enable_t5_ti, + pure_textual_inversion=pure_textual_inversion, + token_abstraction_dict=train_dataset.token_abstraction_dict, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + ) + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py index 7e1a0298ba1d..024722536d88 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py @@ -71,7 +71,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
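The end of the Flux advanced script above saves the transformer (and optionally CLIP text encoder) LoRA layers with `FluxPipeline.save_lora_weights` and, for pivotal tuning, writes the learned embeddings to `<output_dir>_emb.safetensors`; the version-bump hunks for the other example scripts follow. A hedged sketch of reusing the trained LoRA for inference afterwards, where the base model id, output path, and prompt are placeholders rather than values from the script:

```py
# Hedged sketch: reloading the trained LoRA for inference.
# The base model id, output path, and prompt below are placeholders.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("path/to/output_dir")  # same directory passed to save_lora_weights above

image = pipe("a photo of sks dog in a bucket", num_inference_steps=28).images[0]
image.save("lora_sample.png")
```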
-check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 5222c8afe6f1..bc06cc9213dc 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -79,7 +79,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py index 0fdca2850784..4ef392baa2b5 100644 --- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py +++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py @@ -61,7 +61,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index ece2228147e2..011466bc7d58 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -52,7 +52,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/community/README.md b/examples/community/README.md index 267c8f4bb904..4f16f65df8fa 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -4336,19 +4336,19 @@ The Abstract of the paper: **64x64** :-------------------------: -| bird_64 | +| bird_64_64 | - `256×256, nesting_level=1`: 1.776 GiB. With `150` DDIM inference steps: **64x64** | **256x256** :-------------------------:|:-------------------------: -| 64x64 | 256x256 | +| bird_256_64 | bird_256_256 | -- `1024×1024, nesting_level=2`: 1.792 GiB. As one can realize the cost of adding another layer is really negligible. With `250` DDIM inference steps: +- `1024×1024, nesting_level=2`: 1.792 GiB. As one can realize the cost of adding another layer is really negligible in this context! With `250` DDIM inference steps: **64x64** | **256x256** | **1024x1024** :-------------------------:|:-------------------------:|:-------------------------: -| 64x64 | 256x256 | 1024x1024 | +| bird_1024_64 | bird_1024_256 | bird_1024_1024 | ```py from diffusers import DiffusionPipeline @@ -4362,8 +4362,7 @@ pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-model prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree" prompt = f"breathtaking {prompt0}. 
award-winning, professional, highly detailed" -negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy" -image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images +image = pipe(prompt, num_inference_steps=50).images make_image_grid(image, rows=1, cols=len(image)) # pipe.change_nesting_level() # 0, 1, or 2 diff --git a/examples/community/README_community_scripts.md b/examples/community/README_community_scripts.md index 8432b4e82c9f..2c2f549a2bd5 100644 --- a/examples/community/README_community_scripts.md +++ b/examples/community/README_community_scripts.md @@ -8,6 +8,7 @@ If a community script doesn't work as expected, please open an issue and ping th |:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:| | Using IP-Adapter with negative noise | Using negative noise with IP-adapter to better control the generation (see the [original post](https://github.com/huggingface/diffusers/discussions/7167) on the forum for more details) | [IP-Adapter Negative Noise](#ip-adapter-negative-noise) | | [Álvaro Somoza](https://github.com/asomoza)| | asymmetric tiling |configure seamless image tiling independently for the X and Y axes | [Asymmetric Tiling](#asymmetric-tiling ) | | [alexisrolland](https://github.com/alexisrolland)| +| Prompt scheduling callback |Allows changing prompts during a generation | [Prompt Scheduling](#prompt-scheduling ) | | [hlky](https://github.com/hlky)| ## Example usages @@ -229,4 +230,86 @@ seamless_tiling(pipeline=pipeline, x_axis=False, y_axis=False) torch.cuda.empty_cache() image.save('image.png') -``` \ No newline at end of file +``` + +### Prompt Scheduling callback + +Prompt scheduling callback allows changing prompts during a generation, like [prompt editing in A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-editing) + +```python +from diffusers import StableDiffusionPipeline +from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks +from diffusers.configuration_utils import register_to_config +import torch +from typing import Any, Dict, Optional + + +pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", + torch_dtype=torch.float16, + variant="fp16", + use_safetensors=True, +).to("cuda") +pipeline.safety_checker = None +pipeline.requires_safety_checker = False + + +class SDPromptScheduleCallback(PipelineCallback): + @register_to_config + def __init__( + self, + prompt: str, + negative_prompt: Optional[str] = None, + num_images_per_prompt: int = 1, + 
cutoff_step_ratio=1.0, + cutoff_step_index=None, + ): + super().__init__( + cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index + ) + + tensor_inputs = ["prompt_embeds"] + + def callback_fn( + self, pipeline, step_index, timestep, callback_kwargs + ) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index + if cutoff_step_index is not None + else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt( + prompt=self.config.prompt, + negative_prompt=self.config.negative_prompt, + device=pipeline._execution_device, + num_images_per_prompt=self.config.num_images_per_prompt, + do_classifier_free_guidance=pipeline.do_classifier_free_guidance, + ) + if pipeline.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + return callback_kwargs + +callback = MultiPipelineCallbacks( + [ + SDPromptScheduleCallback( + prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski", + negative_prompt="Deformed, ugly, bad anatomy", + cutoff_step_ratio=0.25, + ) + ] +) + +image = pipeline( + prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski", + negative_prompt="Deformed, ugly, bad anatomy", + callback_on_step_end=callback, + callback_on_step_end_tensor_inputs=["prompt_embeds"], +).images[0] +``` diff --git a/examples/community/marigold_depth_estimation.py b/examples/community/marigold_depth_estimation.py index 92f01d046ef9..a8f406309a52 100644 --- a/examples/community/marigold_depth_estimation.py +++ b/examples/community/marigold_depth_estimation.py @@ -43,7 +43,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") class MarigoldDepthOutput(BaseOutput): diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 7ef1438f7204..7ac0ab542910 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -107,15 +107,16 @@ >>> # nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64 >>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models", - >>> custom_pipeline="matryoshka").to("cuda") + ... nesting_level=0, + ... trust_remote_code=False, # One needs to give permission for this code to run + ... ).to("cuda") >>> prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree" >>> prompt = f"breathtaking {prompt0}. 
award-winning, professional, highly detailed" - >>> negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy" - >>> image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images + >>> image = pipe(prompt, num_inference_steps=50).images >>> make_image_grid(image, rows=1, cols=len(image)) - >>> pipe.change_nesting_level() # 0, 1, or 2 + >>> # pipe.change_nesting_level() # 0, 1, or 2 >>> # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively. ``` """ @@ -420,6 +421,7 @@ def __init__( self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) self.scales = None + self.schedule_shifted_power = 1.0 def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor: """ @@ -532,6 +534,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic def get_schedule_shifted(self, alpha_prod, scale_factor=None): if (scale_factor is not None) and (scale_factor > 1): # rescale noise schedule + scale_factor = scale_factor**self.schedule_shifted_power snr = alpha_prod / (1 - alpha_prod) scaled_snr = snr / scale_factor alpha_prod = 1 / (1 + 1 / scaled_snr) @@ -639,17 +642,14 @@ def step( # 4. Clip or threshold "predicted x_0" if self.config.thresholding: if len(model_output) > 1: - pred_original_sample = [ - self._threshold_sample(p_o_s * scale) / scale - for p_o_s, scale in zip(pred_original_sample, self.scales) - ] + pred_original_sample = [self._threshold_sample(p_o_s) for p_o_s in pred_original_sample] else: pred_original_sample = self._threshold_sample(pred_original_sample) elif self.config.clip_sample: if len(model_output) > 1: pred_original_sample = [ - (p_o_s * scale).clamp(-self.config.clip_sample_range, self.config.clip_sample_range) / scale - for p_o_s, scale in zip(pred_original_sample, self.scales) + p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range) + for p_o_s in pred_original_sample ] else: pred_original_sample = pred_original_sample.clamp( @@ -3816,6 +3816,8 @@ def __init__( if hasattr(unet, "nest_ratio"): scheduler.scales = unet.nest_ratio + [1] + if nesting_level == 2: + scheduler.schedule_shifted_power = 2.0 self.register_modules( text_encoder=text_encoder, @@ -3842,12 +3844,14 @@ def change_nesting_level(self, nesting_level: int): ).to(self.device) self.config.nesting_level = 1 self.scheduler.scales = self.unet.nest_ratio + [1] + self.scheduler.schedule_shifted_power = 1.0 elif nesting_level == 2: self.unet = NestedUNet2DConditionModel.from_pretrained( "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_2" ).to(self.device) self.config.nesting_level = 2 self.scheduler.scales = self.unet.nest_ratio + [1] + self.scheduler.schedule_shifted_power = 2.0 else: raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.") @@ -4627,8 +4631,8 @@ def __call__( image = latents if self.scheduler.scales is not None: - for i, (img, scale) in enumerate(zip(image, self.scheduler.scales)): - image[i] = self.image_processor.postprocess(img * scale, output_type=output_type)[0] + for i, img in enumerate(image): + image[i] = self.image_processor.postprocess(img, output_type=output_type)[0] else: image = self.image_processor.postprocess(image, output_type=output_type) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py index 611026675daf..0750df79eb0d 100644 --- 
a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py @@ -73,7 +73,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py index 8090926974c4..493742691286 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py @@ -66,7 +66,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py index fa7e7f1febee..824f148c58fd 100644 --- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py @@ -79,7 +79,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index 12d7db09a361..a334c27e7d86 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -72,7 +72,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index cc5e6812127e..6e5e85172f14 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -78,7 +78,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/controlnet/README_sd3.md b/examples/controlnet/README_sd3.md index 1788e07a21d6..7a7b4841125f 100644 --- a/examples/controlnet/README_sd3.md +++ b/examples/controlnet/README_sd3.md @@ -104,7 +104,7 @@ from diffusers.utils import load_image import torch base_model_path = "stabilityai/stable-diffusion-3-medium-diffusers" -controlnet_path = "sd3-controlnet-out/checkpoint-6500/controlnet" +controlnet_path = "DavyMorgan/sd3-controlnet-out" controlnet = SD3ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) pipe = StableDiffusion3ControlNetPipeline.from_pretrained( diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 2e902db7ffc7..a2aa266cdfbc 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -60,7 +60,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
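The `get_schedule_shifted` change in the matryoshka scheduler hunk above raises the rescale factor to `schedule_shifted_power` before shifting the schedule in SNR space (`snr = alpha_prod / (1 - alpha_prod)`, `scaled_snr = snr / scale_factor`, `alpha_prod = 1 / (1 + 1 / scaled_snr)`). A toy numeric sketch of that transform, with made-up values:

```py
# Toy illustration of the SNR-space schedule shift; values are made up,
# only the arithmetic mirrors get_schedule_shifted.
def shift_alpha(alpha_prod, scale_factor, schedule_shifted_power=1.0):
    scale_factor = scale_factor**schedule_shifted_power
    snr = alpha_prod / (1 - alpha_prod)
    scaled_snr = snr / scale_factor
    return 1 / (1 + 1 / scaled_snr)

print(shift_alpha(0.9, scale_factor=4))                              # ~0.69
print(shift_alpha(0.9, scale_factor=4, schedule_shifted_power=2.0))  # ~0.36
```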
-check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) @@ -1048,7 +1048,9 @@ def load_model_hook(models, input_dir): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to( + dtype=weight_dtype + ) # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 1aa9e881fca5..44c286cd2a40 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -60,7 +60,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py index 5969218f3c3e..ca822b16eae2 100644 --- a/examples/controlnet/train_controlnet_flux.py +++ b/examples/controlnet/train_controlnet_flux.py @@ -65,7 +65,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py index 4fae8a072c6f..2bb68220e268 100644 --- a/examples/controlnet/train_controlnet_sd3.py +++ b/examples/controlnet/train_controlnet_sd3.py @@ -50,7 +50,7 @@ ) from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory -from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.torch_utils import is_compiled_module @@ -59,22 +59,11 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.30.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) -def image_grid(imgs, rows, cols): - assert len(imgs) == rows * cols - - w, h = imgs[0].size - grid = Image.new("RGB", size=(cols * w, rows * h)) - - for i, img in enumerate(imgs): - grid.paste(img, box=(i % cols * w, i // cols * h)) - return grid - - def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_validation=False): logger.info("Running validation... 
") @@ -224,7 +213,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N validation_image.save(os.path.join(repo_folder, "image_control.png")) img_str += f"prompt: {validation_prompt}\n" images = [validation_image] + images - image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) img_str += f"![images_{i})](./images_{i}.png)\n" model_description = f""" diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 877ca6135849..c034c027cbcd 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -61,7 +61,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): @@ -1210,7 +1210,9 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to( + dtype=weight_dtype + ) # ControlNet conditioning. controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index e498ca98b1c7..151817247350 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -63,7 +63,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 5099107118e4..4b614807cfc4 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -63,7 +63,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 29fd5e78535d..3023b28aca7f 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -35,7 +35,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") # Cache compiled models across invocations of this script. cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache")) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 8e0f4e09a461..db4788281cf2 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -64,7 +64,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 5d7d697bb21d..bf778693a88d 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -70,7 +70,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 11cba745cc4a..b09e5b38b2b1 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -72,7 +72,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 8d0b6853eeec..4b39dcfe41b0 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -72,7 +72,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) @@ -86,6 +86,15 @@ def save_model_card( validation_prompt=None, repo_folder=None, ): + if "large" in base_model: + model_variant = "SD3.5-Large" + license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md" + variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"] + else: + model_variant = "SD3" + license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md" + variant_tags = ["sd3", "sd3-diffusers"] + widget_dict = [] if images is not None: for i, image in enumerate(images): @@ -95,7 +104,7 @@ def save_model_card( ) model_description = f""" -# SD3 DreamBooth LoRA - {repo_id} +# {model_variant} DreamBooth LoRA - {repo_id} @@ -120,7 +129,7 @@ def save_model_card( ```py from diffusers import AutoPipelineForText2Image import torch -pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16).to('cuda') +pipeline = AutoPipelineForText2Image.from_pretrained({base_model}, torch_dtype=torch.float16).to('cuda') pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors') image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0] ``` @@ -135,7 +144,7 @@ def save_model_card( ## License -Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE). +Please adhere to the licensing terms as described [here]({license_url}). 
""" model_card = load_or_create_model_card( repo_id_or_path=repo_id, @@ -151,11 +160,11 @@ def save_model_card( "diffusers-training", "diffusers", "lora", - "sd3", - "sd3-diffusers", "template:sd-lora", ] + tags += variant_tags + model_card = populate_model_card(model_card, tags=tags) model_card.save(os.path.join(repo_folder, "README.md")) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 016464165c44..bf8c8f7d0578 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -78,7 +78,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 455ba5a9293d..5d10345304ab 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -63,7 +63,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) @@ -77,6 +77,15 @@ def save_model_card( validation_prompt=None, repo_folder=None, ): + if "large" in base_model: + model_variant = "SD3.5-Large" + license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md" + variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"] + else: + model_variant = "SD3" + license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md" + variant_tags = ["sd3", "sd3-diffusers"] + widget_dict = [] if images is not None: for i, image in enumerate(images): @@ -86,7 +95,7 @@ def save_model_card( ) model_description = f""" -# SD3 DreamBooth - {repo_id} +# {model_variant} DreamBooth - {repo_id} @@ -113,7 +122,7 @@ def save_model_card( ## License -Please adhere to the licensing terms as described `[here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)`. +Please adhere to the licensing terms as described `[here]({license_url})`. """ model_card = load_or_create_model_card( repo_id_or_path=repo_id, @@ -128,10 +137,9 @@ def save_model_card( "text-to-image", "diffusers-training", "diffusers", - "sd3", - "sd3-diffusers", "template:sd-lora", ] + tags += variant_tags model_card = populate_model_card(model_card, tags=tags) model_card.save(os.path.join(repo_folder, "README.md")) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index 3cb0c6702599..125368841fa8 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -57,7 +57,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index c88be6d16d88..4cb9f0e1c544 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -60,7 +60,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py index 9caa3694d636..40016f797341 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py @@ -52,7 +52,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py index 23f5d342b396..3ec622c09239 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py @@ -46,7 +46,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py index 6ed3377db131..fbd843bc3307 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py @@ -46,7 +46,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py index 448429444448..c264a4ce8c7c 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py @@ -51,7 +51,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py index 02a064fa81ed..e694d709360c 100644 --- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py +++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py @@ -60,7 +60,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 684bf352a6c1..6857df61d0c2 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -57,7 +57,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index 4a80067d693d..712bc34429a0 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -49,7 +49,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 379519b4c812..5f432fcc7adf 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -56,7 +56,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index fe098c8638d5..9a4fa23fada3 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -68,7 +68,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index bcf0fa9eb0ac..b34feb6f715c 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -55,7 +55,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) if is_torch_npu_available(): diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 6b710531836b..43e8bf4e9072 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -81,7 +81,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index ee7b1580d145..fff633e75684 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -56,7 +56,7 @@ # ------------------------------------------------------------------------------ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = logging.getLogger(__name__) diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py index a4629f0f43d6..3a9da9fb11df 100644 --- a/examples/textual_inversion/textual_inversion_sdxl.py +++ b/examples/textual_inversion/textual_inversion_sdxl.py @@ -76,7 +76,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 1f5e1de240cb..a80e4c55190d 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -29,7 +29,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py index d16dce921896..b56e39847983 100644 --- a/examples/vqgan/train_vqgan.py +++ b/examples/vqgan/train_vqgan.py @@ -50,7 +50,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py index b4c9a44bb5b2..d57d910599ee 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py @@ -50,7 +50,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py index eba8de69203a..2d9df8387333 100644 --- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py +++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py @@ -51,7 +51,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.31.0.dev0") +check_min_version("0.32.0.dev0") logger = get_logger(__name__, log_level="INFO") diff --git a/scripts/convert_sd3_to_diffusers.py b/scripts/convert_sd3_to_diffusers.py index 4f32745dae75..1f9c434b39d0 100644 --- a/scripts/convert_sd3_to_diffusers.py +++ b/scripts/convert_sd3_to_diffusers.py @@ -16,10 +16,9 @@ parser = argparse.ArgumentParser() parser.add_argument("--checkpoint_path", type=str) parser.add_argument("--output_path", type=str) -parser.add_argument("--dtype", type=str, default="fp16") +parser.add_argument("--dtype", type=str) args = parser.parse_args() -dtype = torch.float16 if args.dtype == "fp16" else torch.float32 def load_original_checkpoint(ckpt_path): @@ -40,7 +39,9 @@ def swap_scale_shift(weight, dim): return new_weight -def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim): +def convert_sd3_transformer_checkpoint_to_diffusers( + original_state_dict, num_layers, caption_projection_dim, dual_attention_layers, has_qk_norm +): converted_state_dict = {} # Positional and patch embeddings. @@ -110,6 +111,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v]) converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + # qk norm + if has_qk_norm: + converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn.ln_k.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.context_block.attn.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.context_block.attn.ln_k.weight" + ) + # output projections. 
converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop( f"joint_blocks.{i}.x_block.attn.proj.weight" @@ -125,6 +141,39 @@ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_lay f"joint_blocks.{i}.context_block.attn.proj.bias" ) + # attn2 + if i in dual_attention_layers: + # Q, K, V + sample_q2, sample_k2, sample_v2 = torch.chunk( + original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0 + ) + sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk( + original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0 + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias]) + + # qk norm + if has_qk_norm: + converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn2.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn2.ln_k.weight" + ) + + # output projections. + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn2.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = original_state_dict.pop( + f"joint_blocks.{i}.x_block.attn2.proj.bias" + ) + # norms. converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop( f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight" @@ -195,25 +244,79 @@ def is_vae_in_checkpoint(original_state_dict): ) +def get_attn2_layers(state_dict): + attn2_layers = [] + for key in state_dict.keys(): + if "attn2." in key: + # Extract the layer number from the key + layer_num = int(key.split(".")[1]) + attn2_layers.append(layer_num) + return tuple(sorted(set(attn2_layers))) + + +def get_pos_embed_max_size(state_dict): + num_patches = state_dict["pos_embed"].shape[1] + pos_embed_max_size = int(num_patches**0.5) + return pos_embed_max_size + + +def get_caption_projection_dim(state_dict): + caption_projection_dim = state_dict["context_embedder.weight"].shape[0] + return caption_projection_dim + + def main(args): original_ckpt = load_original_checkpoint(args.checkpoint_path) + original_dtype = next(iter(original_ckpt.values())).dtype + + # Initialize dtype with a default value + dtype = None + + if args.dtype is None: + dtype = original_dtype + elif args.dtype == "fp16": + dtype = torch.float16 + elif args.dtype == "bf16": + dtype = torch.bfloat16 + elif args.dtype == "fp32": + dtype = torch.float32 + else: + raise ValueError(f"Unsupported dtype: {args.dtype}") + + if dtype != original_dtype: + print( + f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution." 
+ ) + num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1 # noqa: C401 - caption_projection_dim = 1536 + + caption_projection_dim = get_caption_projection_dim(original_ckpt) + + # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5 + attn2_layers = get_attn2_layers(original_ckpt) + + # sd3.5 use qk norm("rms_norm") + has_qk_norm = any("ln_q" in key for key in original_ckpt.keys()) + + # sd3.5 2b use pox_embed_max_size=384 and sd3.0 and sd3.5 8b use 192 + pos_embed_max_size = get_pos_embed_max_size(original_ckpt) converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers( - original_ckpt, num_layers, caption_projection_dim + original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm ) with CTX(): transformer = SD3Transformer2DModel( - sample_size=64, + sample_size=128, patch_size=2, in_channels=16, joint_attention_dim=4096, num_layers=num_layers, caption_projection_dim=caption_projection_dim, - num_attention_heads=24, - pos_embed_max_size=192, + num_attention_heads=num_layers, + pos_embed_max_size=pos_embed_max_size, + qk_norm="rms_norm" if has_qk_norm else None, + dual_attention_layers=attn2_layers, ) if is_accelerate_available(): load_model_dict_into_meta(transformer, converted_transformer_state_dict) diff --git a/setup.py b/setup.py index 89e18cab629b..d82ecad86771 100644 --- a/setup.py +++ b/setup.py @@ -130,7 +130,7 @@ "regex!=2019.12.17", "requests", "tensorboard", - "torch>=1.4", + "torch>=1.4,<2.5.0", "torchvision", "transformers>=4.41.2", "urllib3<=2.0.0", @@ -254,7 +254,7 @@ def run(self): setup( name="diffusers", - version="0.31.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="0.32.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) description="State-of-the-art diffusion in PyTorch and JAX.", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 019d744730ab..789458a26299 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.31.0.dev0" +__version__ = "0.32.0.dev0" from typing import TYPE_CHECKING @@ -31,6 +31,7 @@ "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], + "quantizers.quantization_config": ["BitsAndBytesConfig"], "schedulers": [], "utils": [ "OptionalDependencyNotAvailable", @@ -124,7 +125,6 @@ "VQModel", ] ) - _import_structure["optimization"] = [ "get_constant_schedule", "get_constant_schedule_with_warmup", @@ -156,6 +156,7 @@ "StableDiffusionMixin", ] ) + _import_structure["quantizers"] = ["DiffusersQuantizer"] _import_structure["schedulers"].extend( [ "AmusedScheduler", @@ -256,6 +257,7 @@ "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", "CLIPImageProjection", + "CogVideoXFunControlPipeline", "CogVideoXImageToVideoPipeline", "CogVideoXPipeline", "CogVideoXVideoToVideoPipeline", @@ -537,6 +539,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .configuration_utils import ConfigMixin + from .quantizers.quantization_config import BitsAndBytesConfig try: if not is_onnx_available(): @@ -631,6 +634,7 @@ ScoreSdeVePipeline, StableDiffusionMixin, ) + from .quantizers import DiffusersQuantizer from .schedulers import ( AmusedScheduler, CMStochasticIterativeScheduler, @@ -711,6 +715,7 @@ AudioLDMPipeline, AuraFlowPipeline, CLIPImageProjection, + 
CogVideoXFunControlPipeline, CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline, diff --git a/src/diffusers/callbacks.py b/src/diffusers/callbacks.py index 38542407e31f..4b8b15368c47 100644 --- a/src/diffusers/callbacks.py +++ b/src/diffusers/callbacks.py @@ -97,13 +97,17 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s class SDXLCFGCutoffCallback(PipelineCallback): """ - Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or - `cutoff_step_index`), this callback will disable the CFG. + Callback function for the base Stable Diffusion XL Pipelines. After certain number of steps (set by + `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG. Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. """ - tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"] + tensor_inputs = [ + "prompt_embeds", + "add_text_embeds", + "add_time_ids", + ] def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: cutoff_step_ratio = self.config.cutoff_step_ratio @@ -129,6 +133,55 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s callback_kwargs[self.tensor_inputs[0]] = prompt_embeds callback_kwargs[self.tensor_inputs[1]] = add_text_embeds callback_kwargs[self.tensor_inputs[2]] = add_time_ids + + return callback_kwargs + + +class SDXLControlnetCFGCutoffCallback(PipelineCallback): + """ + Callback function for the Controlnet Stable Diffusion XL Pipelines. After certain number of steps (set by + `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG. + + Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. + """ + + tensor_inputs = [ + "prompt_embeds", + "add_text_embeds", + "add_time_ids", + "image", + ] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + prompt_embeds = callback_kwargs[self.tensor_inputs[0]] + prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens. 
+ + add_text_embeds = callback_kwargs[self.tensor_inputs[1]] + add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens + + add_time_ids = callback_kwargs[self.tensor_inputs[2]] + add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector + + # For Controlnet + image = callback_kwargs[self.tensor_inputs[3]] + image = image[-1:] + + pipeline._guidance_scale = 0.0 + + callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + callback_kwargs[self.tensor_inputs[1]] = add_text_embeds + callback_kwargs[self.tensor_inputs[2]] = add_time_ids + callback_kwargs[self.tensor_inputs[3]] = image + return callback_kwargs diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index 3dccd785cae4..11d45dc64d97 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -510,6 +510,9 @@ def extract_init_dict(cls, config_dict, **kwargs): # remove private attributes config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")} + # remove quantization_config + config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"} + # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments init_dict = {} for key in expected_keys: @@ -586,10 +589,19 @@ def to_json_saveable(value): value = value.as_posix() return value + if "quantization_config" in config_dict: + config_dict["quantization_config"] = ( + config_dict.quantization_config.to_dict() + if not isinstance(config_dict.quantization_config, dict) + else config_dict.quantization_config + ) + config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()} # Don't save "_ignore_files" or "_use_default_values" config_dict.pop("_ignore_files", None) config_dict.pop("_use_default_values", None) + # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. 
+ _ = config_dict.pop("_pre_quantization_dtype", None) return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 9e7bf242eca7..0e421b71e48d 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -38,7 +38,7 @@ "regex": "regex!=2019.12.17", "requests": "requests", "tensorboard": "tensorboard", - "torch": "torch>=1.4", + "torch": "torch>=1.4,<2.5.0", "torchvision": "torchvision", "transformers": "transformers>=4.41.2", "urllib3": "urllib3<=2.0.0", diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 236fbd0c2295..d1bad8b5a7cd 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -75,6 +75,7 @@ "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight", "stable_cascade_stage_c": "clip_txt_mapper.weight", "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias", + "sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight", "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe", "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias", "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight", @@ -113,6 +114,9 @@ "sd3": { "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers", }, + "sd35_large": { + "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large", + }, "animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"}, "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"}, "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"}, @@ -504,9 +508,12 @@ def infer_diffusers_model_type(checkpoint): ): model_type = "stable_cascade_stage_b" - elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint: + elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216: model_type = "sd3" + elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint: + model_type = "sd35_large" + elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint: if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint: model_type = "animatediff_scribble" @@ -1670,6 +1677,22 @@ def swap_scale_shift(weight, dim): return new_weight +def get_attn2_layers(state_dict): + attn2_layers = [] + for key in state_dict.keys(): + if "attn2." 
in key: + # Extract the layer number from the key + layer_num = int(key.split(".")[1]) + attn2_layers.append(layer_num) + + return tuple(sorted(set(attn2_layers))) + + +def get_caption_projection_dim(state_dict): + caption_projection_dim = state_dict["context_embedder.weight"].shape[0] + return caption_projection_dim + + def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): converted_state_dict = {} keys = list(checkpoint.keys()) @@ -1678,7 +1701,10 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1 # noqa: C401 - caption_projection_dim = 1536 + dual_attention_layers = get_attn2_layers(checkpoint) + + caption_projection_dim = get_caption_projection_dim(checkpoint) + has_qk_norm = any("ln_q" in key for key in checkpoint.keys()) # Positional and patch embeddings. converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed") @@ -1735,6 +1761,21 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v]) converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + # qk norm + if has_qk_norm: + converted_state_dict[f"transformer_blocks.{i}.attn.norm_q.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_k.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn.ln_k.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_q.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.attn.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn.norm_added_k.weight"] = checkpoint.pop( + f"joint_blocks.{i}.context_block.attn.ln_k.weight" + ) + # output projections. converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop( f"joint_blocks.{i}.x_block.attn.proj.weight" @@ -1750,6 +1791,38 @@ def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): f"joint_blocks.{i}.context_block.attn.proj.bias" ) + if i in dual_attention_layers: + # Q, K, V + sample_q2, sample_k2, sample_v2 = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0 + ) + sample_q2_bias, sample_k2_bias, sample_v2_bias = torch.chunk( + checkpoint.pop(f"joint_blocks.{i}.x_block.attn2.qkv.bias"), 3, dim=0 + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = torch.cat([sample_q2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = torch.cat([sample_q2_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = torch.cat([sample_k2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = torch.cat([sample_k2_bias]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = torch.cat([sample_v2]) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = torch.cat([sample_v2_bias]) + + # qk norm + if has_qk_norm: + converted_state_dict[f"transformer_blocks.{i}.attn2.norm_q.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn2.ln_q.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.norm_k.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn2.ln_k.weight" + ) + + # output projections. 
+ converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn2.proj.weight" + ) + converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop( + f"joint_blocks.{i}.x_block.attn2.proj.bias" + ) + # norms. converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop( f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight" diff --git a/src/diffusers/loaders/textual_inversion.py b/src/diffusers/loaders/textual_inversion.py index 574b89233cc1..30098c955d6b 100644 --- a/src/diffusers/loaders/textual_inversion.py +++ b/src/diffusers/loaders/textual_inversion.py @@ -561,6 +561,8 @@ def unload_textual_inversion( tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id key_id += 1 tokenizer._update_trie() + # set correct total vocab size after removing tokens + tokenizer._update_total_vocab_size() # Delete from text encoder text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 84db0d061768..02ed1f965abf 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -22,7 +22,7 @@ from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU from .attention_processor import Attention, JointAttnProcessor2_0 from .embeddings import SinusoidalPositionalEmbedding -from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm +from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX logger = logging.get_logger(__name__) @@ -100,13 +100,25 @@ class JointTransformerBlock(nn.Module): processing of `context` conditions. """ - def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + context_pre_only: bool = False, + qk_norm: Optional[str] = None, + use_dual_attention: bool = False, + ): super().__init__() + self.use_dual_attention = use_dual_attention self.context_pre_only = context_pre_only context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero" - self.norm1 = AdaLayerNormZero(dim) + if use_dual_attention: + self.norm1 = SD35AdaLayerNormZeroX(dim) + else: + self.norm1 = AdaLayerNormZero(dim) if context_norm_type == "ada_norm_continous": self.norm1_context = AdaLayerNormContinuous( @@ -118,12 +130,14 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_onl raise ValueError( f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`" ) + if hasattr(F, "scaled_dot_product_attention"): processor = JointAttnProcessor2_0() else: raise ValueError( "The current PyTorch version does not support the `scaled_dot_product_attention` function." 
) + self.attn = Attention( query_dim=dim, cross_attention_dim=None, @@ -134,8 +148,25 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_onl context_pre_only=context_pre_only, bias=True, processor=processor, + qk_norm=qk_norm, + eps=1e-6, ) + if use_dual_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, + processor=processor, + qk_norm=qk_norm, + eps=1e-6, + ) + else: + self.attn2 = None + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") @@ -159,7 +190,12 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor ): - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + if self.use_dual_attention: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( + hidden_states, emb=temb + ) + else: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) if self.context_pre_only: norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb) @@ -177,6 +213,11 @@ def forward( attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = hidden_states + attn_output + if self.use_dual_attention: + attn_output2 = self.attn2(hidden_states=norm_hidden_states2) + attn_output2 = gate_msa2.unsqueeze(1) * attn_output2 + hidden_states = hidden_states + attn_output2 + norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] if self._chunk_size is not None: diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index d333590982e3..e735c4ee7d17 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -193,7 +193,7 @@ def __init__( self.norm_q = RMSNorm(dim_head, eps=eps) self.norm_k = RMSNorm(dim_head, eps=eps) else: - raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'") if cross_attention_norm is None: self.norm_cross = None @@ -250,6 +250,10 @@ def __init__( elif qk_norm == "rms_norm": self.norm_added_q = RMSNorm(dim_head, eps=eps) self.norm_added_k = RMSNorm(dim_head, eps=eps) + else: + raise ValueError( + f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`" + ) else: self.norm_added_q = None self.norm_added_k = None @@ -1050,61 +1054,72 @@ def __call__( ) -> torch.FloatTensor: residual = hidden_states - input_ndim = hidden_states.ndim - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - context_input_ndim = encoder_hidden_states.ndim - if context_input_ndim == 4: - batch_size, channel, height, width = encoder_hidden_states.shape - encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size = encoder_hidden_states.shape[0] + batch_size = hidden_states.shape[0] # `sample` projections. 
query = attn.to_q(hidden_states) key = attn.to_k(hidden_states) value = attn.to_v(hidden_states) - # `context` projections. - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - # attention - query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) - key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) - value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) - inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - # Split the attention outputs. - hidden_states, encoder_hidden_states = ( - hidden_states[:, : residual.shape[1]], - hidden_states[:, residual.shape[1] :], - ) + if encoder_hidden_states is not None: + # Split the attention outputs. 
+ hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) # linear proj hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) - if not attn.context_pre_only: - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - if context_input_ndim == 4: - encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - return hidden_states, encoder_hidden_states + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states class PAGJointAttnProcessor2_0: diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index 7834206ddb4a..68b49d72acc5 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -1182,7 +1182,8 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor: frame_batch_size = self.num_sample_frames_batch_size # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k. - num_batches = num_frames // frame_batch_size if num_frames > 1 else 1 + # As the extra single frame is handled inside the loop, it is not required to round up here. + num_batches = max(num_frames // frame_batch_size, 1) conv_cache = None enc = [] @@ -1330,7 +1331,8 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: row = [] for j in range(0, width, overlap_width): # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k. - num_batches = num_frames // frame_batch_size if num_frames > 1 else 1 + # As the extra single frame is handled inside the loop, it is not required to round up here. 
+ num_batches = max(num_frames // frame_batch_size, 1) conv_cache = None time = [] @@ -1409,7 +1411,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod for i in range(0, height, overlap_height): row = [] for j in range(0, width, overlap_width): - num_batches = num_frames // frame_batch_size + num_batches = max(num_frames // frame_batch_size, 1) conv_cache = None time = [] diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index c9eb664443b5..932a94571107 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -25,6 +25,7 @@ import torch from huggingface_hub.utils import EntryNotFoundError +from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( SAFE_WEIGHTS_INDEX_NAME, SAFETENSORS_FILE_EXTENSION, @@ -54,11 +55,36 @@ # Adapted from `transformers` (see modeling_utils.py) -def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype): +def _determine_device_map( + model: torch.nn.Module, device_map, max_memory, torch_dtype, keep_in_fp32_modules=[], hf_quantizer=None +): if isinstance(device_map, str): + special_dtypes = {} + if hf_quantizer is not None: + special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, torch_dtype)) + special_dtypes.update( + { + name: torch.float32 + for name, _ in model.named_parameters() + if any(m in name for m in keep_in_fp32_modules) + } + ) + + target_dtype = torch_dtype + if hf_quantizer is not None: + target_dtype = hf_quantizer.adjust_target_dtype(target_dtype) + no_split_modules = model._get_no_split_modules(device_map) device_map_kwargs = {"no_split_module_classes": no_split_modules} + if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters: + device_map_kwargs["special_dtypes"] = special_dtypes + elif len(special_dtypes) > 0: + logger.warning( + "This model has some weights that should be kept in higher precision, you need to upgrade " + "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)." + ) + if device_map != "sequential": max_memory = get_balanced_memory( model, @@ -70,8 +96,14 @@ def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_ else: max_memory = get_max_memory(max_memory) + if hf_quantizer is not None: + max_memory = hf_quantizer.adjust_max_memory(max_memory) + device_map_kwargs["max_memory"] = max_memory - device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs) + device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs) + + if hf_quantizer is not None: + hf_quantizer.validate_environment(device_map=device_map) return device_map @@ -100,6 +132,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ """ Reads a checkpoint file, returning properly formatted errors if they arise. """ + # TODO: We merge the sharded checkpoints in case we're doing quantization. We can revisit this change + # when refactoring the _merge_sharded_checkpoints() method later. 
+ if isinstance(checkpoint_file, dict): + return checkpoint_file try: file_extension = os.path.basename(checkpoint_file).split(".")[-1] if file_extension == SAFETENSORS_FILE_EXTENSION: @@ -137,29 +173,67 @@ def load_model_dict_into_meta( device: Optional[Union[str, torch.device]] = None, dtype: Optional[Union[str, torch.dtype]] = None, model_name_or_path: Optional[str] = None, + hf_quantizer=None, + keep_in_fp32_modules=None, ) -> List[str]: - device = device or torch.device("cpu") + if hf_quantizer is None: + device = device or torch.device("cpu") dtype = dtype or torch.float32 + is_quantized = hf_quantizer is not None + is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) - - unexpected_keys = [] empty_state_dict = model.state_dict() + unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict] + for param_name, param in state_dict.items(): if param_name not in empty_state_dict: - unexpected_keys.append(param_name) continue + set_module_kwargs = {} + # We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params + # in int/uint/bool and not cast them. + # TODO: revisit cases when param.dtype == torch.float8_e4m3fn + if torch.is_floating_point(param): + if ( + keep_in_fp32_modules is not None + and any( + module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules + ) + and dtype == torch.float16 + ): + param = param.to(torch.float32) + if accepts_dtype: + set_module_kwargs["dtype"] = torch.float32 + else: + param = param.to(dtype) + if accepts_dtype: + set_module_kwargs["dtype"] = dtype + + # bnb params are flattened. if empty_state_dict[param_name].shape != param.shape: - model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" - raise ValueError( - f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." - ) + if ( + is_quant_method_bnb + and hf_quantizer.pre_quantized + and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device) + ): + hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape) + elif not is_quant_method_bnb: + model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" + raise ValueError( + f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." 
+ ) - if accepts_dtype: - set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) + if is_quantized and ( + hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device) + ): + hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys) else: - set_module_tensor_to_device(model, param_name, device, value=param) + if accepts_dtype: + set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs) + else: + set_module_tensor_to_device(model, param_name, device, value=param) + return unexpected_keys @@ -231,6 +305,35 @@ def _fetch_index_file( return index_file +# Adapted from +# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64 +def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata): + weight_map = sharded_metadata.get("weight_map", None) + if weight_map is None: + raise KeyError("'weight_map' key not found in the shard index file.") + + # Collect all unique safetensors files from weight_map + files_to_load = set(weight_map.values()) + is_safetensors = all(f.endswith(".safetensors") for f in files_to_load) + merged_state_dict = {} + + # Load tensors from each unique file + for file_name in files_to_load: + part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name) + if not os.path.exists(part_file_path): + raise FileNotFoundError(f"Part file {file_name} not found.") + + if is_safetensors: + with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f: + for tensor_key in f.keys(): + if tensor_key in weight_map: + merged_state_dict[tensor_key] = f.get_tensor(tensor_key) + else: + merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu")) + + return merged_state_dict + + def _fetch_index_file_legacy( is_local, pretrained_model_name_or_path, diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index ad3433889fca..4a486fd4ce40 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -14,13 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import inspect import itertools import json import os import re from collections import OrderedDict -from functools import partial +from functools import partial, wraps from pathlib import Path from typing import Any, Callable, List, Optional, Tuple, Union @@ -31,6 +32,8 @@ from torch import Tensor, nn from .. 
import __version__ +from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer +from ..quantizers.quantization_config import QuantizationMethod from ..utils import ( CONFIG_NAME, FLAX_WEIGHTS_NAME, @@ -43,6 +46,8 @@ _get_model_file, deprecate, is_accelerate_available, + is_bitsandbytes_available, + is_bitsandbytes_version, is_torch_version, logging, ) @@ -56,6 +61,7 @@ _fetch_index_file, _fetch_index_file_legacy, _load_state_dict_into_model, + _merge_sharded_checkpoints, load_model_dict_into_meta, load_state_dict, ) @@ -125,6 +131,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): _supports_gradient_checkpointing = False _keys_to_ignore_on_load_unexpected = None _no_split_modules = None + _keep_in_fp32_modules = None def __init__(self): super().__init__() @@ -308,6 +315,19 @@ def save_pretrained( logger.error(f"Provided path ({save_directory}) should be a directory, not a file") return + hf_quantizer = getattr(self, "hf_quantizer", None) + if hf_quantizer is not None: + quantization_serializable = ( + hf_quantizer is not None + and isinstance(hf_quantizer, DiffusersQuantizer) + and hf_quantizer.is_serializable + ) + if not quantization_serializable: + raise ValueError( + f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from" + " the logger on the traceback to understand the reason why the quantized model is not serializable." + ) + weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME weights_name = _add_variant(weights_name, variant) weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace( @@ -402,6 +422,18 @@ def save_pretrained( create_pr=create_pr, ) + def dequantize(self): + """ + Potentially dequantize the model in case it has been quantized by a quantization method that support + dequantization. + """ + hf_quantizer = getattr(self, "hf_quantizer", None) + + if hf_quantizer is None: + raise ValueError("You need to first quantize your model in order to dequantize it") + + return hf_quantizer.dequantize(self) + @classmethod @validate_hf_hub_args def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): @@ -524,6 +556,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) + quantization_config = kwargs.pop("quantization_config", None) allow_pickle = False if use_safetensors is None: @@ -618,6 +651,60 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P user_agent=user_agent, **kwargs, ) + # no in-place modification of the original config. + config = copy.deepcopy(config) + + # determine initial quantization config. 
+ ####################################### + pre_quantized = "quantization_config" in config and config["quantization_config"] is not None + if pre_quantized or quantization_config is not None: + if pre_quantized: + config["quantization_config"] = DiffusersAutoQuantizer.merge_quantization_configs( + config["quantization_config"], quantization_config + ) + else: + config["quantization_config"] = quantization_config + hf_quantizer = DiffusersAutoQuantizer.from_config( + config["quantization_config"], pre_quantized=pre_quantized + ) + else: + hf_quantizer = None + + if hf_quantizer is not None: + if device_map is not None: + raise NotImplementedError( + "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future." + ) + hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map) + torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype) + + # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry` + user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value + + # Force-set to `True` for more mem efficiency + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.") + elif not low_cpu_mem_usage: + raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.") + + # Check if `_keep_in_fp32_modules` is not None + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( + (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules") + ) + if use_keep_in_fp32_modules: + keep_in_fp32_modules = cls._keep_in_fp32_modules + if not isinstance(keep_in_fp32_modules, list): + keep_in_fp32_modules = [keep_in_fp32_modules] + + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.") + elif not low_cpu_mem_usage: + raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.") + else: + keep_in_fp32_modules = [] + ####################################### # Determine if we're loading from a directory of sharded checkpoints. is_sharded = False @@ -684,6 +771,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P revision=revision, subfolder=subfolder or "", ) + if hf_quantizer is not None: + model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) + logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") + is_sharded = False elif use_safetensors and not is_sharded: try: @@ -729,13 +820,30 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P with accelerate.init_empty_weights(): model = cls.from_config(config, **unused_kwargs) + if hf_quantizer is not None: + hf_quantizer.preprocess_model( + model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules + ) + # if device_map is None, load the state dict and move the params from meta device to the cpu if device_map is None and not is_sharded: - param_device = "cpu" + # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None. + # It would error out during the `validate_environment()` call above in the absence of cuda. 
+ is_quant_method_bnb = ( + getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ) + if hf_quantizer is None: + param_device = "cpu" + # TODO (sayakpaul, SunMarc): remove this after model loading refactor + elif is_quant_method_bnb: + param_device = torch.cuda.current_device() state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict) + # move the params from meta device to cpu missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if hf_quantizer is not None: + missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="") if len(missing_keys) > 0: raise ValueError( f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" @@ -750,6 +858,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P device=param_device, dtype=torch_dtype, model_name_or_path=pretrained_model_name_or_path, + hf_quantizer=hf_quantizer, + keep_in_fp32_modules=keep_in_fp32_modules, ) if cls._keys_to_ignore_on_load_unexpected is not None: @@ -765,7 +875,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # Load weights and dispatch according to the device_map # by default the device_map is None and the weights are loaded on the CPU force_hook = True - device_map = _determine_device_map(model, device_map, max_memory, torch_dtype) + device_map = _determine_device_map( + model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer + ) if device_map is None and is_sharded: # we load the parameters on the cpu device_map = {"": "cpu"} @@ -843,14 +955,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P "error_msgs": error_msgs, } + if hf_quantizer is not None: + hf_quantizer.postprocess_model(model) + model.hf_quantizer = hf_quantizer + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): raise ValueError( f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." ) - elif torch_dtype is not None: + # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will + # completely lose the effectivity of `use_keep_in_fp32_modules`. + elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules: model = model.to(torch_dtype) - model.register_to_config(_name_or_path=pretrained_model_name_or_path) + if hf_quantizer is not None: + # We also make sure to purge `_pre_quantization_dtype` when we serialize + # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable. + model.register_to_config(_name_or_path=pretrained_model_name_or_path, _pre_quantization_dtype=torch_dtype) + else: + model.register_to_config(_name_or_path=pretrained_model_name_or_path) # Set model in evaluation mode to deactivate DropOut modules by default model.eval() @@ -859,6 +982,76 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P return model + # Adapted from `transformers`. + @wraps(torch.nn.Module.cuda) + def cuda(self, *args, **kwargs): + # Checks if the model has been loaded in 4-bit or 8-bit with BNB + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "Calling `cuda()` is not supported for `8-bit` quantized models. " + " Please use the model as it is, since the model has already been set to the correct devices." 
+ ) + elif is_bitsandbytes_version("<", "0.43.2"): + raise ValueError( + "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + return super().cuda(*args, **kwargs) + + # Adapted from `transformers`. + @wraps(torch.nn.Module.to) + def to(self, *args, **kwargs): + dtype_present_in_args = "dtype" in kwargs + + if not dtype_present_in_args: + for arg in args: + if isinstance(arg, torch.dtype): + dtype_present_in_args = True + break + + # Checks if the model has been loaded in 4-bit or 8-bit with BNB + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + if dtype_present_in_args: + raise ValueError( + "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the" + " desired `dtype` by passing the correct `torch_dtype` argument." + ) + + if getattr(self, "is_loaded_in_8bit", False): + raise ValueError( + "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." + ) + elif is_bitsandbytes_version("<", "0.43.2"): + raise ValueError( + "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. " + f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2." + ) + return super().to(*args, **kwargs) + + # Taken from `transformers`. + def half(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.half()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been cast to the correct `dtype`." + ) + else: + return super().half(*args) + + # Taken from `transformers`. + def float(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.float()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been cast to the correct `dtype`." + ) + else: + return super().float(*args) + @classmethod def _load_pretrained_model( cls, @@ -1041,19 +1234,63 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool 859520964 ``` """ + is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False) + + if is_loaded_in_4bit: + if is_bitsandbytes_available(): + import bitsandbytes as bnb + else: + raise ValueError( + "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong" + " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. 
" + ) if exclude_embeddings: embedding_param_names = [ - f"{name}.weight" - for name, module_type in self.named_modules() - if isinstance(module_type, torch.nn.Embedding) + f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding) ] - non_embedding_parameters = [ + total_parameters = [ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names ] - return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) else: - return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + total_parameters = list(self.parameters()) + + total_numel = [] + + for param in total_parameters: + if param.requires_grad or not only_trainable: + # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are + # used for the 4bit quantization (uint8 tensors are stored) + if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit): + if hasattr(param, "element_size"): + num_bytes = param.element_size() + elif hasattr(param, "quant_storage"): + num_bytes = param.quant_storage.itemsize + else: + num_bytes = 1 + total_numel.append(param.numel() * 2 * num_bytes) + else: + total_numel.append(param.numel()) + + return sum(total_numel) + + def get_memory_footprint(self, return_buffers=True): + r""" + Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. + Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the + PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 + + Arguments: + return_buffers (`bool`, *optional*, defaults to `True`): + Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers + are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch + norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 + """ + mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) + if return_buffers: + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) + mem = mem + mem_bufs + return mem def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None: deprecated_attention_block_paths = [] diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 21e9d3cd6fc5..029c147fcbac 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -97,6 +97,40 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: ).to(origin_dtype) +class SD35AdaLayerNormZeroX(nn.Module): + r""" + Norm layer adaptive layer norm zero (AdaLN-Zero). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + num_embeddings (`int`): The size of the embeddings dictionary. + """ + + def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None: + super().__init__() + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias) + if norm_type == "layer_norm": + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + else: + raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. 
Supported ones are: 'layer_norm'.") + + def forward( + self, + hidden_states: torch.Tensor, + emb: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, ...]: + emb = self.linear(self.silu(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk( + 9, dim=1 + ) + norm_hidden_states = self.norm(hidden_states) + hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None] + norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None] + return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 + + class AdaLayerNormZero(nn.Module): r""" Norm layer adaptive layer norm zero (adaLN-Zero). diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 9376c91d0756..b28350b8ed9c 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -69,6 +69,10 @@ def __init__( pooled_projection_dim: int = 2048, out_channels: int = 16, pos_embed_max_size: int = 96, + dual_attention_layers: Tuple[ + int, ... + ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5 + qk_norm: Optional[str] = None, ): super().__init__() default_out_channels = in_channels @@ -97,6 +101,8 @@ def __init__( num_attention_heads=self.config.num_attention_heads, attention_head_dim=self.config.attention_head_dim, context_pre_only=i == num_layers - 1, + qk_norm=qk_norm, + use_dual_attention=True if i in dual_attention_layers else False, ) for i in range(self.config.num_layers) ] diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index d7ff34310beb..7366520f4692 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -144,6 +144,7 @@ "CogVideoXPipeline", "CogVideoXImageToVideoPipeline", "CogVideoXVideoToVideoPipeline", + "CogVideoXFunControlPipeline", ] _import_structure["cogview3"] = ["CogView3PlusPipeline"] _import_structure["controlnet"].extend( @@ -470,7 +471,12 @@ ) from .aura_flow import AuraFlowPipeline from .blip_diffusion import BlipDiffusionPipeline - from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline + from .cogvideo import ( + CogVideoXFunControlPipeline, + CogVideoXImageToVideoPipeline, + CogVideoXPipeline, + CogVideoXVideoToVideoPipeline, + ) from .cogview3 import CogView3PlusPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 6fe5c5604c86..181448fc2f5e 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -113,9 +113,21 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) @@ -135,7 +147,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index de5a16b05b40..c7ff97ce4226 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -119,7 +119,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py index 1d26f95a2f58..9a93f1d28d35 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py @@ -131,7 +131,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 1e3ba67adf5d..ae55b3cab17e 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -54,7 +54,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 
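The hunks above expand the `rescale_noise_cfg` docstring, but the behaviour is easiest to see numerically. Below is a small, self-contained sketch (plain PyTorch, random tensors for illustration only) of the std-matching rescale and the `guidance_rescale` blend the docstring describes; it mirrors the formula visible in the surrounding context lines and is not part of the diff itself.

```python
import torch


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Match the per-sample std of the CFG combination to the std of the text-conditioned
    # prediction, then blend back toward the original CFG output by `guidance_rescale`.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg


# Illustrative tensors: classifier-free guidance tends to inflate the std of the prediction.
noise_pred_text = torch.randn(2, 4, 64, 64)
noise_cfg = 1.5 * torch.randn(2, 4, 64, 64)

out = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
print(noise_cfg.std().item(), out.std().item())  # the rescaled output sits closer to the text-pred std
```

With `guidance_rescale=0.0` the input is returned unchanged, matching the documented default.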
diff --git a/src/diffusers/pipelines/cogvideo/__init__.py b/src/diffusers/pipelines/cogvideo/__init__.py index bd60fcea9994..e4fa1dda53d3 100644 --- a/src/diffusers/pipelines/cogvideo/__init__.py +++ b/src/diffusers/pipelines/cogvideo/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"] + _import_structure["pipeline_cogvideox_fun_control"] = ["CogVideoXFunControlPipeline"] _import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"] _import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"] @@ -35,6 +36,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_cogvideox import CogVideoXPipeline + from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index fdafdbc7c019..44ed94333e4a 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -86,7 +86,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. @@ -521,7 +521,7 @@ def __call__( num_frames (`int`, defaults to `48`): Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where - num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that needs to be satisfied is that of divisibility mentioned above. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py new file mode 100644 index 000000000000..3655075bd519 --- /dev/null +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -0,0 +1,794 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI, Alibaba-PAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import math +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from PIL import Image +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import CogVideoXLoraLoaderMixin +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXFunControlPipeline, DDIMScheduler + >>> from diffusers.utils import export_to_video, load_video + + >>> pipe = CogVideoXFunControlPipeline.from_pretrained( + ... "alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + >>> pipe.to("cuda") + + >>> control_video = load_video( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4" + ... ) + >>> prompt = ( + ... "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and " + ... "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in " + ... "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, " + ... "moons, but the remainder of the scene is mostly realistic." + ... ) + + >>> video = pipe(prompt=prompt, control_video=control_video).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
+ timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): + r""" + Pipeline for controlled text-to-video generation using CogVideoX Fun. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. 
+ """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->vae->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + self.vae_scaling_factor_image = ( + self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.prepare_latents + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Adapted from https://github.com/aigc-apps/CogVideoX-Fun/blob/2a93e5c14e02b2b5921d533fd59fc8c0ed69fb24/cogvideox/pipeline/pipeline_cogvideox_control.py#L366 + def prepare_control_latents( + self, mask: Optional[torch.Tensor] = None, masked_image: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + if mask is not None: + masks = [] + for i in range(mask.size(0)): + current_mask = mask[i].unsqueeze(0) + current_mask = self.vae.encode(current_mask)[0] + current_mask = current_mask.mode() + masks.append(current_mask) + mask = torch.cat(masks, dim=0) + mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + mask_pixel_values = [] + for i in range(masked_image.size(0)): + mask_pixel_value = masked_image[i].unsqueeze(0) + mask_pixel_value = self.vae.encode(mask_pixel_value)[0] + mask_pixel_value = mask_pixel_value.mode() + mask_pixel_values.append(mask_pixel_value) + masked_image_latents = torch.cat(mask_pixel_values, dim=0) + masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae_scaling_factor_image * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + control_video=None, + control_video_latents=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if control_video is not None and control_video_latents is not None: + raise ValueError( + "Cannot pass both `control_video` and `control_video_latents`. Please make sure to pass only one of these parameters." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. 
Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + control_video: Optional[List[Image.Image]] = None, + height: int = 480, + width: int = 720, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + control_video_latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: str = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + control_video (`List[PIL.Image.Image]`): + The control video to condition the generation on. Must be a list of images/frames of the video. If not + provided, `control_video_latents` must be provided. + height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial): + The height in pixels of the generated image. 
This is set to 480 by default for the best results. + width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial): + The width in pixels of the generated image. This is set to 720 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 6.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2 of the [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + control_video_latents (`torch.Tensor`, *optional*): + Pre-generated control latents, sampled from a Gaussian distribution, to be used as inputs for + controlled video generation. If not provided, `control_video` must be provided. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] instead + of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step during inference.
The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + control_video, + control_video_latents, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if control_video is not None and isinstance(control_video[0], Image.Image): + control_video = [control_video] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. 
+ latent_channels = self.transformer.config.in_channels // 2 + num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2) + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + if control_video_latents is None: + control_video = self.video_processor.preprocess_video(control_video, height=height, width=width) + control_video = control_video.to(device=device, dtype=prompt_embeds.dtype) + + _, control_video_latents = self.prepare_control_latents(None, control_video) + control_video_latents = control_video_latents.permute(0, 2, 1, 3, 4) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + latent_control_input = ( + torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents + ) + latent_model_input = torch.cat([latent_model_input, latent_control_input], dim=2) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, 
output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 975e8ed27db8..783dae569bec 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -88,7 +88,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. @@ -604,7 +604,7 @@ def __call__( num_frames (`int`, defaults to `48`): Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where - num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that needs to be satisfied is that of divisibility mentioned above. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 35f8f2fa0508..e1e816eca16d 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -94,7 +94,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py index 7ae86421c45e..64fff61d2c32 100644 --- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py +++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py @@ -57,7 +57,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index b06ba4e7cba1..4e80aa5a4b6a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -101,7 +101,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 
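Most of the hunks around here only switch `retrieve_timesteps` docstrings to raw strings, but the helper's contract (exactly one of `num_inference_steps`, `timesteps`, or `sigmas`) is worth a concrete example. The sketch below imports the copy that ships in the CogVideoX pipeline module touched above and drives it with a hypothetical stub scheduler, so no assumptions are made about which real schedulers accept custom `timesteps`.

```python
import torch

from diffusers.pipelines.cogvideo.pipeline_cogvideox import retrieve_timesteps


class StubScheduler:
    """Hypothetical scheduler whose `set_timesteps` accepts custom `timesteps` (illustration only)."""

    def set_timesteps(self, num_inference_steps=None, device=None, timesteps=None, sigmas=None):
        if timesteps is not None:
            self.timesteps = torch.tensor(timesteps, device=device)
        else:
            self.timesteps = torch.linspace(999, 0, num_inference_steps, device=device).long()


scheduler = StubScheduler()

# Default spacing: only `num_inference_steps` is given.
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=8)
print(num_steps, timesteps)

# Custom schedule: `timesteps` overrides the spacing and `num_inference_steps` is derived from its length.
timesteps, num_steps = retrieve_timesteps(scheduler, timesteps=[999, 749, 499, 249, 0])
print(num_steps, timesteps)

# Passing both `timesteps` and `sigmas` raises a ValueError, as the docstring states.
```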
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index fa753bbfa98b..beb7d1c15d94 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -137,9 +137,21 @@ def retrieve_latents( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 67a96493a2eb..df652239df63 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -122,7 +122,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. @@ -242,6 +242,7 @@ class StableDiffusionXLControlNetPipeline( "add_time_ids", "negative_pooled_prompt_embeds", "negative_add_time_ids", + "image", ] def __init__( @@ -1541,6 +1542,7 @@ def __call__( ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + image = callback_outputs.pop("image", image) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py index 8e738c456802..0be580eb7f8a 100644 --- a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +++ b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py @@ -141,9 +141,21 @@ def get_resize_crop_region_for_grid(src, tgt_size): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py index 02b4e46ab182..1764675df70b 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -83,7 +83,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py index 3b099721eb2d..170e1fc54deb 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py @@ -108,7 +108,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py index 0d4ca3799a53..1f7f4db728ad 100644 --- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py @@ -65,9 +65,21 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. 
+ guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) @@ -87,7 +99,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py index 54823937f58e..b3635be77ab4 100644 --- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -127,7 +127,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index f2040b184eac..e8f01273828a 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -20,7 +20,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -86,7 +86,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. @@ -137,7 +137,12 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin): +class FluxPipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, +): r""" The Flux pipeline for text-to-image generation. 
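Worth a usage note: the SDXL ControlNet hunk earlier in this section adds `"image"` to `_callback_tensor_inputs` and pops it back out of `callback_outputs`, so step-end callbacks can now read or replace the prepared control image mid-sampling. A hedged sketch follows; the checkpoints are the public SDXL base and Canny ControlNet repos, the conditioning image would normally be a Canny edge map, and the halfway cut-off is purely illustrative.

```python
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

cond_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)


def fade_control(pipeline, step, timestep, callback_kwargs):
    # `image` shows up here only because it is now listed in `_callback_tensor_inputs`.
    control = callback_kwargs["image"]
    if step == pipeline.num_timesteps // 2:
        # Drop the conditioning for the second half of sampling (illustrative, not a recommendation).
        callback_kwargs["image"] = torch.zeros_like(control)
    return callback_kwargs


result = pipe(
    "a futuristic logo, studio lighting",
    image=cond_image,
    callback_on_step_end=fade_control,
    callback_on_step_end_tensor_inputs=["image"],
).images[0]
```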
@@ -212,6 +217,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -255,6 +263,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index cb47b9be95d0..efefb80bc9dc 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -25,7 +25,7 @@ ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel @@ -106,7 +106,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. @@ -238,6 +238,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -281,6 +284,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index a7f7c66a2cad..8d636feeae05 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -11,7 +11,7 @@ ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel @@ -118,7 +118,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 
@@ -251,6 +251,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -295,6 +298,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", @@ -897,9 +903,12 @@ def __call__( timestep = t.expand(latents.shape[0]).to(latents.dtype) - guidance = ( - torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None - ) + if isinstance(self.controlnet, FluxMultiControlNetModel): + use_guidance = self.controlnet.nets[0].config.guidance_embeds + else: + use_guidance = self.controlnet.config.guidance_embeds + + guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None guidance = guidance.expand(latents.shape[0]) if guidance is not None else None if isinstance(controlnet_keep[i], list): diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 50d2fcaa7fa5..46784f2d46d1 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -12,7 +12,7 @@ ) from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin +from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel from ...models.transformers import FluxTransformer2DModel @@ -120,7 +120,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 
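The image-to-image Flux ControlNet hunk above stops reading `guidance_embeds` directly off `self.controlnet`, because a `FluxMultiControlNetModel` only exposes its children through `.nets`. The snippet below is a minimal stand-alone sketch of that resolution logic; the stub classes are hypothetical stand-ins, not the real diffusers models.

```python
from dataclasses import dataclass


@dataclass
class StubConfig:
    guidance_embeds: bool


class StubControlNet:
    def __init__(self, guidance_embeds: bool):
        self.config = StubConfig(guidance_embeds)


class StubMultiControlNet:
    def __init__(self, nets):
        self.nets = nets  # the wrapper itself has no `config.guidance_embeds`


def resolve_use_guidance(controlnet) -> bool:
    # Mirrors the branch added in the diff: read the flag from the first wrapped net when a
    # multi-controlnet wrapper is passed, otherwise from the model's own config.
    if isinstance(controlnet, StubMultiControlNet):
        return controlnet.nets[0].config.guidance_embeds
    return controlnet.config.guidance_embeds


single = StubControlNet(guidance_embeds=True)
multi = StubMultiControlNet([StubControlNet(guidance_embeds=False), StubControlNet(guidance_embeds=True)])

print(resolve_use_guidance(single))  # True
print(resolve_use_guidance(multi))   # False, taken from nets[0]
```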
@@ -261,6 +261,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -305,6 +308,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index 6375b1a994a1..14b062180003 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -20,7 +20,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin +from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -108,7 +108,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. @@ -235,6 +235,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -279,6 +282,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index b9114c923cba..2f0be67bf88c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -21,7 +21,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FluxLoraLoaderMixin +from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import FluxTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -105,7 +105,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 
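The recurring one-character `"""` -> `r"""` edits in these files turn the helper docstrings into raw strings, presumably so that any backslash sequences that land in them (LaTeX-style symbols, escape-like notation in examples) are kept verbatim instead of being treated as string escapes. A two-variable illustration of the difference:

plain = "guidance term \tau in the schedule"   # "\t" is a real escape: it becomes a TAB character
raw = r"guidance term \tau in the schedule"    # raw string: backslash and "t" survive as typed

print(len(plain), len(raw))   # the raw variant is one character longer
print(plain == raw)           # False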
@@ -239,6 +239,9 @@ def _get_t5_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) + text_inputs = self.tokenizer_2( prompt, padding="max_length", @@ -283,6 +286,9 @@ def _get_clip_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + text_inputs = self.tokenizer( prompt, padding="max_length", diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py index 40f20fe55bfb..304d6ed2a6c9 100644 --- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py @@ -125,9 +125,21 @@ def get_resize_crop_region_for_grid(src, tgt_size): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors.py b/src/diffusers/pipelines/kolors/pipeline_kolors.py index 6f25fd1ac0be..a55d6b97f151 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors.py @@ -70,7 +70,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py index 145daf0181fa..6433b2f1f7f1 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py @@ -89,7 +89,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 
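Because `retrieve_timesteps` appears in almost every file touched here (only its docstring quoting changes), a condensed restatement of its control flow may save a lookup: exactly one of `timesteps`, `sigmas`, or `num_inference_steps` is forwarded to `scheduler.set_timesteps`, with `inspect.signature` used to reject schedulers that cannot accept custom values. This is a simplified paraphrase, not the literal diffusers source:

import inspect

def retrieve_timesteps_sketch(scheduler, num_inference_steps=None, device=None,
                              timesteps=None, sigmas=None, **kwargs):
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
    if timesteps is not None:
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(f"{scheduler.__class__.__name__} does not support custom timesteps.")
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(f"{scheduler.__class__.__name__} does not support custom sigmas.")
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

The hunks above only switch this docstring to a raw string; the behaviour sketched here is unchanged by the diff.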
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index dd72d3c9e10e..e985648abace 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -66,7 +66,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py index 89cafc2877fe..d110cd464522 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py @@ -70,7 +70,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index 8b78a7e75681..b0aa298b34ef 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -76,7 +76,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py index 537b509e3720..83b3753c2c87 100644 --- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py @@ -234,9 +234,21 @@ def __call__( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. 
""" std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py index 995cc15f3f93..834445bfcd06 100644 --- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py @@ -1643,9 +1643,21 @@ def invert( # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 1ded6013fb35..9a4b37972379 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -73,7 +73,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py index 92a648a0d194..f3e4e94c836d 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py @@ -104,7 +104,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py index debcdad85656..6ccf11a3b774 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -126,7 +126,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py index ebcff0ea269a..7a88451273d4 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -128,9 +128,21 @@ def get_resize_crop_region_for_grid(src, tgt_size): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py index 4dc01b7caa64..7c9b6e1e065d 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py @@ -75,7 +75,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py index 8e5e6cbaf5ad..59d6a9001e1f 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py @@ -81,7 +81,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_sd.py index 6a3efb71f8e0..0b0bd5e71aa3 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd.py @@ -60,9 +60,21 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. 
Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) @@ -82,7 +94,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py index 2f0d997359a1..e5bb01dd2409 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py @@ -82,7 +82,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py index e4f26494d5c3..49dc4948cb40 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py @@ -89,7 +89,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py index 904942528a36..d850df080fb5 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -88,9 +88,21 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. 
""" std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) @@ -110,7 +122,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index 2b17537dcfc6..74b7ede6a8d6 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -92,9 +92,21 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) @@ -128,7 +140,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py index 7c2ec2e726ad..e36c1dc77396 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -105,9 +105,21 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. 
+ + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) @@ -141,7 +153,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 0a744264b7a6..5eba1952e608 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -118,6 +118,10 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No components.setdefault(component, []) components[component].append(component_filename) + # If there are no component folders check the main directory for safetensors files + if not components: + return any(".safetensors" in filename for filename in filenames) + # iterate over all files of a component # check if safetensor files exist for that component # if variant is provided check if the variant of the safetensors exists @@ -838,3 +842,108 @@ def get_connected_passed_kwargs(prefix): ) return init_kwargs + + +def _get_custom_components_and_folders( + pretrained_model_name: str, + config_dict: Dict[str, Any], + filenames: Optional[List[str]] = None, + variant_filenames: Optional[List[str]] = None, + variant: Optional[str] = None, +): + config_dict = config_dict.copy() + + # retrieve all folder_names that contain relevant files + folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] + + diffusers_module = importlib.import_module(__name__.split(".")[0]) + pipelines = getattr(diffusers_module, "pipelines") + + # optionally create a custom component <> custom file mapping + custom_components = {} + for component in folder_names: + module_candidate = config_dict[component][0] + + if module_candidate is None or not isinstance(module_candidate, str): + continue + + # We compute candidate file path on the Hub. Do not use `os.path.join`. + candidate_file = f"{component}/{module_candidate}.py" + + if candidate_file in filenames: + custom_components[component] = module_candidate + elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate): + raise ValueError( + f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." + ) + + if len(variant_filenames) == 0 and variant is not None: + error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." 
+ raise ValueError(error_message) + + return custom_components, folder_names + + +def _get_ignore_patterns( + passed_components, + model_folder_names: List[str], + model_filenames: List[str], + variant_filenames: List[str], + use_safetensors: bool, + from_flax: bool, + allow_pickle: bool, + use_onnx: bool, + is_onnx: bool, + variant: Optional[str] = None, +) -> List[str]: + if ( + use_safetensors + and not allow_pickle + and not is_safetensors_compatible( + model_filenames, passed_components=passed_components, folder_names=model_folder_names + ) + ): + raise EnvironmentError( + f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})" + ) + + if from_flax: + ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] + + elif use_safetensors and is_safetensors_compatible( + model_filenames, passed_components=passed_components, folder_names=model_folder_names + ): + ignore_patterns = ["*.bin", "*.msgpack"] + + use_onnx = use_onnx if use_onnx is not None else is_onnx + if not use_onnx: + ignore_patterns += ["*.onnx", "*.pb"] + + safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} + safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} + if len(safetensors_variant_filenames) > 0 and safetensors_model_filenames != safetensors_variant_filenames: + logger.warning( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" + f"[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n" + f"[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not " + f"expected, please check your folder structure." + ) + + else: + ignore_patterns = ["*.safetensors", "*.msgpack"] + + use_onnx = use_onnx if use_onnx is not None else is_onnx + if not use_onnx: + ignore_patterns += ["*.onnx", "*.pb"] + + bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} + bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} + if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: + logger.warning( + f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n" + f"[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n" + f"[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check " + f"your folder structure." 
+ ) + + return ignore_patterns diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 857a13147cfe..2e1858b16148 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -44,6 +44,7 @@ from ..models import AutoencoderKL from ..models.attention_processor import FusedAttnProcessor2_0 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin +from ..quantizers.bitsandbytes.utils import _check_bnb_status from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( CONFIG_NAME, @@ -54,6 +55,7 @@ is_accelerate_version, is_torch_npu_available, is_torch_version, + is_transformers_version, logging, numpy_to_pil, ) @@ -71,15 +73,16 @@ CUSTOM_PIPELINE_FILE_NAME, LOADABLE_CLASSES, _fetch_class_library_tuple, + _get_custom_components_and_folders, _get_custom_pipeline_class, _get_final_device_map, + _get_ignore_patterns, _get_pipeline_class, _identify_model_variants, _maybe_raise_warning_for_inpainting, _resolve_custom_pipeline_and_cls, _unwrap_model, _update_init_kwargs_with_connected_pipeline, - is_safetensors_compatible, load_sub_model, maybe_raise_or_warn, variant_compatible_siblings, @@ -431,18 +434,23 @@ def module_is_offloaded(module): is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded for module in modules: - is_loaded_in_8bit = hasattr(module, "is_loaded_in_8bit") and module.is_loaded_in_8bit + _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module) - if is_loaded_in_8bit and dtype is not None: + if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None: logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision." + f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision." ) - if is_loaded_in_8bit and device is not None: + if is_loaded_in_8bit_bnb and device is not None: logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}." + f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}." ) - else: + + # This can happen for `transformer` models. CPU placement was added in + # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly. + if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"): + module.to(device=device) + elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb: module.to(device, dtype) if ( @@ -1039,9 +1047,18 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t hook = None for model_str in self.model_cpu_offload_seq.split("->"): model = all_model_components.pop(model_str, None) + if not isinstance(model, torch.nn.Module): continue + # This is because the model would already be placed on a CUDA device. + _, _, is_loaded_in_8bit_bnb = _check_bnb_status(model) + if is_loaded_in_8bit_bnb: + logger.info( + f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit." 
+ ) + continue + _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook) self._all_hooks.append(hook) @@ -1298,44 +1315,18 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: config_dict = cls._dict_from_json_file(config_file) ignore_filenames = config_dict.pop("_ignore_files", []) - # retrieve all folder_names that contain relevant files - folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"] - - diffusers_module = importlib.import_module(__name__.split(".")[0]) - pipelines = getattr(diffusers_module, "pipelines") - - # optionally create a custom component <> custom file mapping - custom_components = {} - for component in folder_names: - module_candidate = config_dict[component][0] - - if module_candidate is None or not isinstance(module_candidate, str): - continue - - # We compute candidate file path on the Hub. Do not use `os.path.join`. - candidate_file = f"{component}/{module_candidate}.py" - - if candidate_file in filenames: - custom_components[component] = module_candidate - elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate): - raise ValueError( - f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'." - ) - - if len(variant_filenames) == 0 and variant is not None: - error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available." - raise ValueError(error_message) - # remove ignored filenames model_filenames = set(model_filenames) - set(ignore_filenames) variant_filenames = set(variant_filenames) - set(ignore_filenames) - # if the whole pipeline is cached we don't have to ping the Hub if revision in DEPRECATED_REVISION_ARGS and version.parse( version.parse(__version__).base_version ) >= version.parse("0.22.0"): warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames) + custom_components, folder_names = _get_custom_components_and_folders( + pretrained_model_name, config_dict, filenames, variant_filenames, variant + ) model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names} custom_class_name = None @@ -1395,49 +1386,19 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: expected_components, _ = cls._get_signature_keys(pipeline_class) passed_components = [k for k in expected_components if k in kwargs] - if ( - use_safetensors - and not allow_pickle - and not is_safetensors_compatible( - model_filenames, passed_components=passed_components, folder_names=model_folder_names - ) - ): - raise EnvironmentError( - f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})" - ) - if from_flax: - ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"] - elif use_safetensors and is_safetensors_compatible( - model_filenames, passed_components=passed_components, folder_names=model_folder_names - ): - ignore_patterns = ["*.bin", "*.msgpack"] - - use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx - if not use_onnx: - ignore_patterns += ["*.onnx", "*.pb"] - - safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} - safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} - if ( - len(safetensors_variant_filenames) > 0 - and safetensors_model_filenames != 
safetensors_variant_filenames - ): - logger.warning( - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." - ) - else: - ignore_patterns = ["*.safetensors", "*.msgpack"] - - use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx - if not use_onnx: - ignore_patterns += ["*.onnx", "*.pb"] - - bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} - bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} - if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: - logger.warning( - f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure." - ) + # retrieve all patterns that should not be downloaded and error out when needed + ignore_patterns = _get_ignore_patterns( + passed_components, + model_folder_names, + model_filenames, + variant_filenames, + use_safetensors, + from_flax, + allow_pickle, + use_onnx, + pipeline_class._is_onnx, + variant, + ) # Don't download any objects that are passed allow_patterns = [ diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 5b220df8058b..46d8ad5e6dfa 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -178,7 +178,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py index 69f028914774..b2772d552514 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py @@ -122,7 +122,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 5b5ff23df18e..092785017965 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -69,9 +69,21 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. 
Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) @@ -90,7 +102,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. @@ -149,7 +161,7 @@ class StableDiffusionPipeline( IPAdapterMixin, FromSingleFileMixin, ): - r""" + """ Pipeline for text-to-image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods @@ -1066,7 +1078,6 @@ def __call__( do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload all models diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 0b368ba2ac13..9c90b910e007 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -119,7 +119,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 21c6642b5e9f..4cfe95d0d632 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -60,7 +60,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index f76230b1eeac..e58733ad5397 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -77,7 +77,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 
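Stepping back to the `pipeline_loading_utils.py` refactor earlier in the diff: `_get_ignore_patterns` centralises which weight formats are excluded from the Hub download, and `is_safetensors_compatible` now also accepts flat repositories with no component sub-folders. The selection logic condenses to roughly the following (an illustrative paraphrase of the new helper, not its exact body):

def ignore_patterns_sketch(from_flax, prefer_safetensors, use_onnx):
    if from_flax:
        # Flax checkpoints: exclude every other serialization format.
        return ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
    if prefer_safetensors:
        patterns = ["*.bin", "*.msgpack"]          # keep .safetensors, drop torch pickles and Flax
    else:
        patterns = ["*.safetensors", "*.msgpack"]  # fall back to .bin weights
    if not use_onnx:
        patterns += ["*.onnx", "*.pb"]
    return patterns

# Flat repository with no component folders: any top-level .safetensors file is
# now enough to take the safetensors-compatible path.
flat_repo = ["model.safetensors", "config.json"]
prefer = any(".safetensors" in f for f in flat_repo)
print(ignore_patterns_sketch(from_flax=False, prefer_safetensors=prefer, use_onnx=False))
# ['*.bin', '*.msgpack', '*.onnx', '*.pb']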
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index 6a82ed9f48f5..cef9804e821d 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -98,7 +98,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index 440b6529c9ca..7401be39d6f9 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -97,7 +97,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py index 95532efc6d2c..50d25c08b2f5 100644 --- a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py @@ -61,9 +61,21 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) @@ -83,7 +95,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. 
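One more note on the `pipeline_utils.py` hunks above, since they are the only behavioural change in this stretch of docstring-only edits: the `_check_bnb_status` branches in `DiffusionPipeline.to()` reduce to "never cast a bitsandbytes-quantized module, move 4-bit modules only on transformers newer than 4.44.0, and leave 8-bit modules where they are", and `enable_model_cpu_offload` now skips hook placement for 8-bit modules entirely. A condensed, self-contained restatement of the `.to()` branching, with plain booleans standing in for what `_check_bnb_status` reports:

import torch

def move_module_sketch(module, device=None, dtype=None,
                       loaded_in_4bit=False, loaded_in_8bit=False,
                       transformers_newer_than_4_44=True):
    if (loaded_in_4bit or loaded_in_8bit) and dtype is not None:
        print(f"warning: {type(module).__name__} is bnb-quantized, dtype cast is skipped")
    if loaded_in_8bit and device is not None:
        print(f"warning: {type(module).__name__} is bnb 8-bit, device move is skipped")
    if loaded_in_4bit and device is not None and transformers_newer_than_4_44:
        module.to(device=device)    # 4-bit: device placement only, never a dtype cast
    elif not loaded_in_4bit and not loaded_in_8bit:
        module.to(device, dtype)    # unquantized module: move and cast as before
    return module

layer = torch.nn.Linear(8, 8)
move_module_sketch(layer, device="cpu", dtype=torch.float16)   # moved and cast
move_module_sketch(layer, device="cpu", loaded_in_8bit=True)   # warns, left untouched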
diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py index f6fabeeca011..9337c874faeb 100644 --- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py @@ -61,9 +61,21 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) @@ -83,7 +95,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index c99ce7ef03eb..3f8c7785957e 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -87,9 +87,21 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. 
""" std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) @@ -109,7 +121,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index ddb31f28c472..8b53841cb138 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -90,9 +90,21 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) @@ -126,7 +138,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 199f5cabaa32..6a5fb284c39d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -101,9 +101,21 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. 
+ noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) @@ -153,7 +165,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index dc283108d143..fe3c9afbba1d 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -71,7 +71,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index 3cb7c26bb6a2..1a938aaf9423 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -127,7 +127,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index 0ea197e42e62..20569d0adb32 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -119,9 +119,21 @@ def _preprocess_adapter_image(image, height, width): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. 
+ + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) @@ -141,7 +153,7 @@ def retrieve_timesteps( sigmas: Optional[List[float]] = None, **kwargs, ): - """ + r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py index c46a5ce7c084..9ff473cc3a38 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py @@ -310,9 +310,21 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + r""" + Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on + Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf). + + Args: + noise_cfg (`torch.Tensor`): + The predicted noise tensor for the guided diffusion process. + noise_pred_text (`torch.Tensor`): + The predicted noise tensor for the text-guided diffusion process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + A rescale factor applied to the noise predictions. + + Returns: + noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor. """ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py new file mode 100644 index 000000000000..4c8483a3d6ee --- /dev/null +++ b/src/diffusers/quantizers/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .auto import DiffusersAutoQuantizer +from .base import DiffusersQuantizer diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py new file mode 100644 index 000000000000..97cbcdc0e53f --- /dev/null +++ b/src/diffusers/quantizers/auto.py @@ -0,0 +1,126 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Adapted from +https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py +""" +import warnings +from typing import Dict, Optional, Union + +from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer +from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod + + +AUTO_QUANTIZER_MAPPING = { + "bitsandbytes_4bit": BnB4BitDiffusersQuantizer, + "bitsandbytes_8bit": BnB8BitDiffusersQuantizer, +} + +AUTO_QUANTIZATION_CONFIG_MAPPING = { + "bitsandbytes_4bit": BitsAndBytesConfig, + "bitsandbytes_8bit": BitsAndBytesConfig, +} + + +class DiffusersAutoQuantizer: + """ + The auto diffusers quantizer class that takes care of automatically instantiating to the correct + `DiffusersQuantizer` given the `QuantizationConfig`. + """ + + @classmethod + def from_dict(cls, quantization_config_dict: Dict): + quant_method = quantization_config_dict.get("quant_method", None) + # We need a special care for bnb models to make sure everything is BC .. + if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False): + suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit" + quant_method = QuantizationMethod.BITS_AND_BYTES + suffix + elif quant_method is None: + raise ValueError( + "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized" + ) + + if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys(): + raise ValueError( + f"Unknown quantization type, got {quant_method} - supported types are:" + f" {list(AUTO_QUANTIZER_MAPPING.keys())}" + ) + + target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method] + return target_cls.from_dict(quantization_config_dict) + + @classmethod + def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs): + # Convert it to a QuantizationConfig if the q_config is a dict + if isinstance(quantization_config, dict): + quantization_config = cls.from_dict(quantization_config) + + quant_method = quantization_config.quant_method + + # Again, we need a special care for bnb as we have a single quantization config + # class for both 4-bit and 8-bit quantization + if quant_method == QuantizationMethod.BITS_AND_BYTES: + if quantization_config.load_in_8bit: + quant_method += "_8bit" + else: + quant_method += "_4bit" + + if quant_method not in AUTO_QUANTIZER_MAPPING.keys(): + raise ValueError( + f"Unknown quantization type, got {quant_method} - supported types are:" + f" {list(AUTO_QUANTIZER_MAPPING.keys())}" + ) + + target_cls = AUTO_QUANTIZER_MAPPING[quant_method] + return target_cls(quantization_config, **kwargs) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + model_config = cls.load_config(pretrained_model_name_or_path, **kwargs) + if getattr(model_config, "quantization_config", None) is None: + raise ValueError( + f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. 
Make sure that the model is correctly quantized." + ) + quantization_config_dict = model_config.quantization_config + quantization_config = cls.from_dict(quantization_config_dict) + # Update with potential kwargs that are passed through from_pretrained. + quantization_config.update(kwargs) + + return cls.from_config(quantization_config) + + @classmethod + def merge_quantization_configs( + cls, + quantization_config: Union[dict, QuantizationConfigMixin], + quantization_config_from_args: Optional[QuantizationConfigMixin], + ): + """ + handles situations where both quantization_config from args and quantization_config from model config are + present. + """ + if quantization_config_from_args is not None: + warning_msg = ( + "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading" + " already has a `quantization_config` attribute. The `quantization_config` from the model will be used." + ) + else: + warning_msg = "" + + if isinstance(quantization_config, dict): + quantization_config = cls.from_dict(quantization_config) + + if warning_msg != "": + warnings.warn(warning_msg) + + return quantization_config diff --git a/src/diffusers/quantizers/base.py b/src/diffusers/quantizers/base.py new file mode 100644 index 000000000000..6ec3885fe373 --- /dev/null +++ b/src/diffusers/quantizers/base.py @@ -0,0 +1,233 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Adapted from +https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/quantizers/base.py +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from ..utils import is_torch_available +from .quantization_config import QuantizationConfigMixin + + +if TYPE_CHECKING: + from ..models.modeling_utils import ModelMixin + +if is_torch_available(): + import torch + + +class DiffusersQuantizer(ABC): + """ + Abstract class of the HuggingFace quantizer. Supports for now quantizing HF diffusers models for inference and/or + quantization. This class is used only for diffusers.models.modeling_utils.ModelMixin.from_pretrained and cannot be + easily used outside the scope of that method yet. + + Attributes + quantization_config (`diffusers.quantizers.quantization_config.QuantizationConfigMixin`): + The quantization config that defines the quantization parameters of your model that you want to quantize. + modules_to_not_convert (`List[str]`, *optional*): + The list of module names to not convert when quantizing the model. + required_packages (`List[str]`, *optional*): + The list of required pip packages to install prior to using the quantizer + requires_calibration (`bool`): + Whether the quantization method requires to calibrate the model before using it. 
+ """ + + requires_calibration = False + required_packages = None + + def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): + self.quantization_config = quantization_config + + # -- Handle extra kwargs below -- + self.modules_to_not_convert = kwargs.pop("modules_to_not_convert", []) + self.pre_quantized = kwargs.pop("pre_quantized", True) + + if not self.pre_quantized and self.requires_calibration: + raise ValueError( + f"The quantization method {quantization_config.quant_method} does require the model to be pre-quantized." + f" You explicitly passed `pre_quantized=False` meaning your model weights are not quantized. Make sure to " + f"pass `pre_quantized=True` while knowing what you are doing." + ) + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + """ + Some quantization methods require to explicitly set the dtype of the model to a target dtype. You need to + override this method in case you want to make sure that behavior is preserved + + Args: + torch_dtype (`torch.dtype`): + The input dtype that is passed in `from_pretrained` + """ + return torch_dtype + + def update_device_map(self, device_map: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """ + Override this method if you want to pass a override the existing device map with a new one. E.g. for + bitsandbytes, since `accelerate` is a hard requirement, if no device_map is passed, the device_map is set to + `"auto"`` + + Args: + device_map (`Union[dict, str]`, *optional*): + The device_map that is passed through the `from_pretrained` method. + """ + return device_map + + def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + """ + Override this method if you want to adjust the `target_dtype` variable used in `from_pretrained` to compute the + device_map in case the device_map is a `str`. E.g. for bitsandbytes we force-set `target_dtype` to `torch.int8` + and for 4-bit we pass a custom enum `accelerate.CustomDtype.int4`. + + Args: + torch_dtype (`torch.dtype`, *optional*): + The torch_dtype that is used to compute the device_map. + """ + return torch_dtype + + def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: + """ + Override this method if you want to adjust the `missing_keys`. + + Args: + missing_keys (`List[str]`, *optional*): + The list of missing keys in the checkpoint compared to the state dict of the model + """ + return missing_keys + + def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]: + """ + returns dtypes for modules that are not quantized - used for the computation of the device_map in case one + passes a str as a device_map. The method will use the `modules_to_not_convert` that is modified in + `_process_model_before_weight_loading`. `diffusers` models don't have any `modules_to_not_convert` attributes + yet but this can change soon in the future. + + Args: + model (`~diffusers.models.modeling_utils.ModelMixin`): + The model to quantize + torch_dtype (`torch.dtype`): + The dtype passed in `from_pretrained` method. 
+ """ + + return { + name: torch_dtype + for name, _ in model.named_parameters() + if any(m in name for m in self.modules_to_not_convert) + } + + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + """adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization""" + return max_memory + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + """ + checks if a loaded state_dict component is part of quantized param + some validation; only defined for + quantization methods that require to create a new parameters for quantization. + """ + return False + + def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter": + """ + takes needed components from state_dict and creates quantized param. + """ + return + + def check_quantized_param_shape(self, *args, **kwargs): + """ + checks if the quantized param has expected shape. + """ + return True + + def validate_environment(self, *args, **kwargs): + """ + This method is used to potentially check for potential conflicts with arguments that are passed in + `from_pretrained`. You need to define it for all future quantizers that are integrated with diffusers. If no + explicit check are needed, simply return nothing. + """ + return + + def preprocess_model(self, model: "ModelMixin", **kwargs): + """ + Setting model attributes and/or converting model before weights loading. At this point the model should be + initialized on the meta device so you can freely manipulate the skeleton of the model in order to replace + modules in-place. Make sure to override the abstract method `_process_model_before_weight_loading`. + + Args: + model (`~diffusers.models.modeling_utils.ModelMixin`): + The model to quantize + kwargs (`dict`, *optional*): + The keyword arguments that are passed along `_process_model_before_weight_loading`. + """ + model.is_quantized = True + model.quantization_method = self.quantization_config.quant_method + return self._process_model_before_weight_loading(model, **kwargs) + + def postprocess_model(self, model: "ModelMixin", **kwargs): + """ + Post-process the model post weights loading. Make sure to override the abstract method + `_process_model_after_weight_loading`. + + Args: + model (`~diffusers.models.modeling_utils.ModelMixin`): + The model to quantize + kwargs (`dict`, *optional*): + The keyword arguments that are passed along `_process_model_after_weight_loading`. + """ + return self._process_model_after_weight_loading(model, **kwargs) + + def dequantize(self, model): + """ + Potentially dequantize the model to retrive the original model, with some loss in accuracy / performance. Note + not all quantization schemes support this. + """ + model = self._dequantize(model) + + # Delete quantizer and quantization config + del model.hf_quantizer + + return model + + def _dequantize(self, model): + raise NotImplementedError( + f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub." + ) + + @abstractmethod + def _process_model_before_weight_loading(self, model, **kwargs): + ... + + @abstractmethod + def _process_model_after_weight_loading(self, model, **kwargs): + ... + + @property + @abstractmethod + def is_serializable(self): + ... + + @property + @abstractmethod + def is_trainable(self): + ... 
diff --git a/src/diffusers/quantizers/bitsandbytes/__init__.py b/src/diffusers/quantizers/bitsandbytes/__init__.py new file mode 100644 index 000000000000..9e745bc810fa --- /dev/null +++ b/src/diffusers/quantizers/bitsandbytes/__init__.py @@ -0,0 +1,2 @@ +from .bnb_quantizer import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer +from .utils import dequantize_and_replace, dequantize_bnb_weight, replace_with_bnb_linear diff --git a/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py new file mode 100644 index 000000000000..d5ac1611a571 --- /dev/null +++ b/src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py @@ -0,0 +1,558 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Adapted from +https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/quantizer_bnb_4bit.py +""" + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from ...utils import get_module_from_name +from ..base import DiffusersQuantizer + + +if TYPE_CHECKING: + from ...models.modeling_utils import ModelMixin + +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_bitsandbytes_available, + is_bitsandbytes_version, + is_torch_available, + logging, +) + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class BnB4BitDiffusersQuantizer(DiffusersQuantizer): + """ + 4-bit quantization from bitsandbytes.py quantization method: + before loading: converts transformer layers into Linear4bit during loading: load 16bit weight and pass to the + layer object after: quantizes individual weights in Linear4bit into 4bit at the first .cuda() call saving: + from state dict, as usual; saves weights and `quant_state` components + loading: + need to locate `quant_state` components and pass to Param4bit constructor + """ + + use_keep_in_fp32_modules = True + requires_calibration = False + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + if self.quantization_config.llm_int8_skip_modules is not None: + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + def validate_environment(self, *args, **kwargs): + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. 
A GPU is needed for quantization.") + if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): + raise ImportError( + "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" + ) + if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"): + raise ImportError( + "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" + ) + + if kwargs.get("from_flax", False): + raise ValueError( + "Converting into 4-bit weights from flax weights is currently not supported, please make" + " sure the weights are in PyTorch format." + ) + + device_map = kwargs.get("device_map", None) + if ( + device_map is not None + and isinstance(device_map, dict) + and not self.quantization_config.llm_int8_enable_fp32_cpu_offload + ): + device_map_without_no_convert = { + key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert + } + if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values(): + raise ValueError( + "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " + "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " + "in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom `device_map` to " + "`from_pretrained`. Check " + "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu " + "for more details. " + ) + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if target_dtype != torch.int8: + from accelerate.utils import CustomDtype + + logger.info("target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization") + return CustomDtype.INT4 + else: + raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.") + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ) -> bool: + import bitsandbytes as bnb + + module, tensor_name = get_module_from_name(model, param_name) + if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Params4bit): + # Add here check for loaded components' dtypes once serialization is implemented + return True + elif isinstance(module, bnb.nn.Linear4bit) and tensor_name == "bias": + # bias could be loaded by regular set_module_tensor_to_device() from accelerate, + # but it would wrongly use uninitialized weight there. 
+ return True + else: + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: Optional[List[str]] = None, + ): + import bitsandbytes as bnb + + module, tensor_name = get_module_from_name(model, param_name) + + if tensor_name not in module._parameters: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + + old_value = getattr(module, tensor_name) + + if tensor_name == "bias": + if param_value is None: + new_value = old_value.to(target_device) + else: + new_value = param_value.to(target_device) + + new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad) + module._parameters[tensor_name] = new_value + return + + if not isinstance(module._parameters[tensor_name], bnb.nn.Params4bit): + raise ValueError("this function only loads `Linear4bit components`") + if ( + old_value.device == torch.device("meta") + and target_device not in ["meta", torch.device("meta")] + and param_value is None + ): + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.") + + # construct `new_value` for the module._parameters[tensor_name]: + if self.pre_quantized: + # 4bit loading. Collecting components for restoring quantized weight + # This can be expanded to make a universal call for any quantized weight loading + + if not self.is_serializable: + raise ValueError( + "Detected int4 weights but the version of bitsandbytes is not compatible with int4 serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + + if (param_name + ".quant_state.bitsandbytes__fp4" not in state_dict) and ( + param_name + ".quant_state.bitsandbytes__nf4" not in state_dict + ): + raise ValueError( + f"Supplied state dict for {param_name} does not contain `bitsandbytes__*` and possibly other `quantized_stats` components." + ) + + quantized_stats = {} + for k, v in state_dict.items(): + # `startswith` to counter for edge cases where `param_name` + # substring can be present in multiple places in the `state_dict` + if param_name + "." in k and k.startswith(param_name): + quantized_stats[k] = v + if unexpected_keys is not None and k in unexpected_keys: + unexpected_keys.remove(k) + + new_value = bnb.nn.Params4bit.from_prequantized( + data=param_value, + quantized_stats=quantized_stats, + requires_grad=False, + device=target_device, + ) + else: + new_value = param_value.to("cpu") + kwargs = old_value.__dict__ + new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(target_device) + + module._parameters[tensor_name] = new_value + + def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape): + n = current_param_shape.numel() + inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1) + if loaded_param_shape != inferred_shape: + raise ValueError( + f"Expected the flattened shape of the current param ({param_name}) to be {loaded_param_shape} but is {inferred_shape}." 
+ ) + else: + return True + + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + # need more space for buffers that are created during quantization + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` + logger.info( + "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to " + "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. " + "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass" + " torch_dtype=torch.float16 to remove this warning.", + torch_dtype, + ) + torch_dtype = torch.float16 + return torch_dtype + + # (sayakpaul): I think it could be better to disable custom `device_map`s + # for the first phase of the integration in the interest of simplicity. + # Commenting this for discussions on the PR. + # def update_device_map(self, device_map): + # if device_map is None: + # device_map = {"": torch.cuda.current_device()} + # logger.info( + # "The device_map was not initialized. " + # "Setting device_map to {'':torch.cuda.current_device()}. " + # "If you want to use the model for inference, please set device_map ='auto' " + # ) + # return device_map + + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + from .utils import replace_with_bnb_linear + + load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload + + # We may keep some modules such as the `proj_out` in their original dtype for numerical stability reasons + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + + # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + + if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload: + raise ValueError( + "If you want to offload some keys to `cpu` or `disk`, you need to set " + "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be " + " converted to 8-bit but kept in 32-bit." + ) + self.modules_to_not_convert.extend(keys_on_cpu) + + # Purge `None`. + # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 + # in case of diffusion transformer models. For language models and others alike, `lm_head` + # and tied modules are usually kept in FP32. + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] + + model = replace_with_bnb_linear( + model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config + ) + model.config.quantization_config = self.quantization_config + + def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): + model.is_loaded_in_4bit = True + model.is_4bit_serializable = self.is_serializable + return model + + @property + def is_serializable(self): + # Because we're mandating `bitsandbytes` 0.43.3. 
+ return True + + @property + def is_trainable(self) -> bool: + # Because we're mandating `bitsandbytes` 0.43.3. + return True + + def _dequantize(self, model): + from .utils import dequantize_and_replace + + is_model_on_cpu = model.device.type == "cpu" + if is_model_on_cpu: + logger.info( + "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device." + ) + model.to(torch.cuda.current_device()) + + model = dequantize_and_replace( + model, self.modules_to_not_convert, quantization_config=self.quantization_config + ) + if is_model_on_cpu: + model.to("cpu") + return model + + +class BnB8BitDiffusersQuantizer(DiffusersQuantizer): + """ + 8-bit quantization from bitsandbytes quantization method: + before loading: converts transformer layers into Linear8bitLt during loading: load 16bit weight and pass to the + layer object after: quantizes individual weights in Linear8bitLt into 8bit at fitst .cuda() call + saving: + from state dict, as usual; saves weights and 'SCB' component + loading: + need to locate SCB component and pass to the Linear8bitLt object + """ + + use_keep_in_fp32_modules = True + requires_calibration = False + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + if self.quantization_config.llm_int8_skip_modules is not None: + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + def validate_environment(self, *args, **kwargs): + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"): + raise ImportError( + "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`" + ) + if not is_bitsandbytes_available() or is_bitsandbytes_version("<", "0.43.3"): + raise ImportError( + "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`" + ) + + if kwargs.get("from_flax", False): + raise ValueError( + "Converting into 8-bit weights from flax weights is currently not supported, please make" + " sure the weights are in PyTorch format." + ) + + device_map = kwargs.get("device_map", None) + if ( + device_map is not None + and isinstance(device_map, dict) + and not self.quantization_config.llm_int8_enable_fp32_cpu_offload + ): + device_map_without_no_convert = { + key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert + } + if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values(): + raise ValueError( + "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the " + "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules " + "in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom `device_map` to " + "`from_pretrained`. Check " + "https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu " + "for more details. 
" + ) + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory + def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]: + # need more space for buffers that are created during quantization + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + return max_memory + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_torch_dtype + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` + logger.info( + "Overriding torch_dtype=%s with `torch_dtype=torch.float16` due to " + "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. " + "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass" + " torch_dtype=torch.float16 to remove this warning.", + torch_dtype, + ) + torch_dtype = torch.float16 + return torch_dtype + + # # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map + # def update_device_map(self, device_map): + # if device_map is None: + # device_map = {"": torch.cuda.current_device()} + # logger.info( + # "The device_map was not initialized. " + # "Setting device_map to {'':torch.cuda.current_device()}. " + # "If you want to use the model for inference, please set device_map ='auto' " + # ) + # return device_map + + def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype": + if target_dtype != torch.int8: + logger.info("target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization") + return torch.int8 + + def check_if_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + state_dict: Dict[str, Any], + **kwargs, + ): + import bitsandbytes as bnb + + module, tensor_name = get_module_from_name(model, param_name) + if isinstance(module._parameters.get(tensor_name, None), bnb.nn.Int8Params): + if self.pre_quantized: + if param_name.replace("weight", "SCB") not in state_dict.keys(): + raise ValueError("Missing quantization component `SCB`") + if param_value.dtype != torch.int8: + raise ValueError( + f"Incompatible dtype `{param_value.dtype}` when loading 8-bit prequantized weight. Expected `torch.int8`." 
+ ) + return True + return False + + def create_quantized_param( + self, + model: "ModelMixin", + param_value: "torch.Tensor", + param_name: str, + target_device: "torch.device", + state_dict: Dict[str, Any], + unexpected_keys: Optional[List[str]] = None, + ): + import bitsandbytes as bnb + + fp16_statistics_key = param_name.replace("weight", "SCB") + fp16_weights_format_key = param_name.replace("weight", "weight_format") + + fp16_statistics = state_dict.get(fp16_statistics_key, None) + fp16_weights_format = state_dict.get(fp16_weights_format_key, None) + + module, tensor_name = get_module_from_name(model, param_name) + if tensor_name not in module._parameters: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + + old_value = getattr(module, tensor_name) + + if not isinstance(module._parameters[tensor_name], bnb.nn.Int8Params): + raise ValueError(f"Parameter `{tensor_name}` should only be a `bnb.nn.Int8Params` instance.") + if ( + old_value.device == torch.device("meta") + and target_device not in ["meta", torch.device("meta")] + and param_value is None + ): + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.") + + new_value = param_value.to("cpu") + if self.pre_quantized and not self.is_serializable: + raise ValueError( + "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + + kwargs = old_value.__dict__ + new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(target_device) + + module._parameters[tensor_name] = new_value + if fp16_statistics is not None: + setattr(module.weight, "SCB", fp16_statistics.to(target_device)) + if unexpected_keys is not None: + unexpected_keys.remove(fp16_statistics_key) + + # We just need to pop the `weight_format` keys from the state dict to remove unneeded + # messages. The correct format is correctly retrieved during the first forward pass. 
+ if fp16_weights_format is not None and unexpected_keys is not None: + unexpected_keys.remove(fp16_weights_format_key) + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_after_weight_loading with 4bit->8bit + def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): + model.is_loaded_in_8bit = True + model.is_8bit_serializable = self.is_serializable + return model + + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading + def _process_model_before_weight_loading( + self, + model: "ModelMixin", + device_map, + keep_in_fp32_modules: List[str] = [], + **kwargs, + ): + from .utils import replace_with_bnb_linear + + load_in_8bit_fp32_cpu_offload = self.quantization_config.llm_int8_enable_fp32_cpu_offload + + # We may keep some modules such as the `proj_out` in their original dtype for numerical stability reasons + self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules + + if not isinstance(self.modules_to_not_convert, list): + self.modules_to_not_convert = [self.modules_to_not_convert] + + self.modules_to_not_convert.extend(keep_in_fp32_modules) + + # Extend `self.modules_to_not_convert` to keys that are supposed to be offloaded to `cpu` or `disk` + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + + if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload: + raise ValueError( + "If you want to offload some keys to `cpu` or `disk`, you need to set " + "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be " + " converted to 8-bit but kept in 32-bit." + ) + self.modules_to_not_convert.extend(keys_on_cpu) + + # Purge `None`. + # Unlike `transformers`, we don't know if we should always keep certain modules in FP32 + # in case of diffusion transformer models. For language models and others alike, `lm_head` + # and tied modules are usually kept in FP32. + self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None] + + model = replace_with_bnb_linear( + model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config + ) + model.config.quantization_config = self.quantization_config + + @property + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable + def is_serializable(self): + # Because we're mandating `bitsandbytes` 0.43.3. + return True + + @property + # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable + def is_trainable(self) -> bool: + # Because we're mandating `bitsandbytes` 0.43.3. + return True + + def _dequantize(self, model): + from .utils import dequantize_and_replace + + model = dequantize_and_replace( + model, self.modules_to_not_convert, quantization_config=self.quantization_config + ) + return model diff --git a/src/diffusers/quantizers/bitsandbytes/utils.py b/src/diffusers/quantizers/bitsandbytes/utils.py new file mode 100644 index 000000000000..03755db3d1ec --- /dev/null +++ b/src/diffusers/quantizers/bitsandbytes/utils.py @@ -0,0 +1,306 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Adapted from +https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py +""" + +import inspect +from inspect import signature +from typing import Union + +from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging +from ..quantization_config import QuantizationMethod + + +if is_torch_available(): + import torch + import torch.nn as nn + +if is_bitsandbytes_available(): + import bitsandbytes as bnb + +if is_accelerate_available(): + import accelerate + from accelerate import init_empty_weights + from accelerate.hooks import add_hook_to_module, remove_hook_from_module + +logger = logging.get_logger(__name__) + + +def _replace_with_bnb_linear( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, +): + """ + Private method that wraps the recursion for module replacement. + + Returns the converted model and a boolean that indicates if the conversion has been successfull or not. + """ + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + if not any( + (key + "." 
in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): + with init_empty_weights(): + in_features = module.in_features + out_features = module.out_features + + if quantization_config.quantization_method() == "llm_int8": + model._modules[name] = bnb.nn.Linear8bitLt( + in_features, + out_features, + module.bias is not None, + has_fp16_weights=quantization_config.llm_int8_has_fp16_weight, + threshold=quantization_config.llm_int8_threshold, + ) + has_been_replaced = True + else: + if ( + quantization_config.llm_int8_skip_modules is not None + and name in quantization_config.llm_int8_skip_modules + ): + pass + else: + extra_kwargs = ( + {"quant_storage": quantization_config.bnb_4bit_quant_storage} + if "quant_storage" in list(signature(bnb.nn.Linear4bit).parameters) + else {} + ) + model._modules[name] = bnb.nn.Linear4bit( + in_features, + out_features, + module.bias is not None, + quantization_config.bnb_4bit_compute_dtype, + compress_statistics=quantization_config.bnb_4bit_use_double_quant, + quant_type=quantization_config.bnb_4bit_quant_type, + **extra_kwargs, + ) + has_been_replaced = True + # Store the module class in case we need to transpose the weight later + model._modules[name].source_cls = type(module) + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + if len(list(module.children())) > 0: + _, has_been_replaced = _replace_with_bnb_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None): + """ + Helper function to replace the `nn.Linear` layers within `model` with either `bnb.nn.Linear8bit` or + `bnb.nn.Linear4bit` using the `bitsandbytes` library. + + References: + * `bnb.nn.Linear8bit`: [LLM.int8(): 8-bit Matrix Multiplication for Transformers at + Scale](https://arxiv.org/abs/2208.07339) + * `bnb.nn.Linear4bit`: [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314) + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + modules_to_not_convert (`List[`str`]`, *optional*, defaults to `[]`): + Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `modules_to_not_convert` in + full precision for numerical stability reasons. + current_key_name (`List[`str`]`, *optional*): + An array to track the current key of the recursion. This is used to check whether the current key (part of + it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or + `disk`). + quantization_config ('transformers.utils.quantization_config.BitsAndBytesConfig'): + To configure and manage settings related to quantization, a technique used to compress neural network + models by reducing the precision of the weights and activations, thus making models more efficient in terms + of both storage and computation. + """ + model, has_been_replaced = _replace_with_bnb_linear( + model, modules_to_not_convert, current_key_name, quantization_config + ) + + if not has_been_replaced: + logger.warning( + "You are loading your model in 8bit or 4bit but no linear modules were found in your model." 
+ " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + + return model + + +# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41 +def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None): + """ + Helper function to dequantize 4bit or 8bit bnb weights. + + If the weight is not a bnb quantized weight, it will be returned as is. + """ + if not isinstance(weight, torch.nn.Parameter): + raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead") + + cls_name = weight.__class__.__name__ + if cls_name not in ("Params4bit", "Int8Params"): + return weight + + if cls_name == "Params4bit": + output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) + logger.warning_once( + f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`" + ) + return output_tensor + + if state.SCB is None: + state.SCB = weight.SCB + + im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device) + im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im) + im, Sim = bnb.functional.transform(im, "col32") + if state.CxB is None: + state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB) + out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB) + return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t() + + +def _create_accelerate_new_hook(old_hook): + r""" + Creates a new hook based on the old hook. Use it only if you know what you are doing ! This method is a copy of: + https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245 with + some changes + """ + old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__) + old_hook_attr = old_hook.__dict__ + filtered_old_hook_attr = {} + old_hook_init_signature = inspect.signature(old_hook_cls.__init__) + for k in old_hook_attr.keys(): + if k in old_hook_init_signature.parameters: + filtered_old_hook_attr[k] = old_hook_attr[k] + new_hook = old_hook_cls(**filtered_old_hook_attr) + return new_hook + + +def _dequantize_and_replace( + model, + modules_to_not_convert=None, + current_key_name=None, + quantization_config=None, + has_been_replaced=False, +): + """ + Converts a quantized model into its dequantized original version. The newly converted model will have some + performance drop compared to the original model before quantization - use it only for specific usecases such as + QLoRA adapters merging. + + Returns the converted model and a boolean that indicates if the conversion has been successfull or not. + """ + quant_method = quantization_config.quantization_method() + + target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit + + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + + if isinstance(module, target_cls) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + + if not any( + (key + "." 
in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert + ): + bias = getattr(module, "bias", None) + + device = module.weight.device + with init_empty_weights(): + new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None) + + if quant_method == "llm_int8": + state = module.state + else: + state = None + + new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state)) + + if bias is not None: + new_module.bias = bias + + # Create a new hook and attach it in case we use accelerate + if hasattr(module, "_hf_hook"): + old_hook = module._hf_hook + new_hook = _create_accelerate_new_hook(old_hook) + + remove_hook_from_module(module) + add_hook_to_module(new_module, new_hook) + + new_module.to(device) + model._modules[name] = new_module + has_been_replaced = True + if len(list(module.children())) > 0: + _, has_been_replaced = _dequantize_and_replace( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + has_been_replaced=has_been_replaced, + ) + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def dequantize_and_replace( + model, + modules_to_not_convert=None, + quantization_config=None, +): + model, has_been_replaced = _dequantize_and_replace( + model, + modules_to_not_convert=modules_to_not_convert, + quantization_config=quantization_config, + ) + + if not has_been_replaced: + logger.warning( + "For some reason the model has not been properly dequantized. You might see unexpected behavior." + ) + + return model + + +def _check_bnb_status(module) -> Union[bool, bool]: + is_loaded_in_4bit_bnb = ( + hasattr(module, "is_loaded_in_4bit") + and module.is_loaded_in_4bit + and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ) + is_loaded_in_8bit_bnb = ( + hasattr(module, "is_loaded_in_8bit") + and module.is_loaded_in_8bit + and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ) + return is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py new file mode 100644 index 000000000000..f521c5d717d6 --- /dev/null +++ b/src/diffusers/quantizers/quantization_config.py @@ -0,0 +1,391 @@ +#!/usr/bin/env python +# coding=utf-8 + +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Adapted from +https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e5411a8026/src/transformers/utils/quantization_config.py +""" + +import copy +import importlib.metadata +import json +import os +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, Union + +from packaging import version + +from ..utils import is_torch_available, logging + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class QuantizationMethod(str, Enum): + BITS_AND_BYTES = "bitsandbytes" + + +@dataclass +class QuantizationConfigMixin: + """ + Mixin class for quantization config + """ + + quant_method: QuantizationMethod + _exclude_attributes_at_init = [] + + @classmethod + def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): + """ + Instantiates a [`QuantizationConfigMixin`] from a Python dictionary of parameters. + + Args: + config_dict (`Dict[str, Any]`): + Dictionary that will be used to instantiate the configuration object. + return_unused_kwargs (`bool`,*optional*, defaults to `False`): + Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in + `PreTrainedModel`. + kwargs (`Dict[str, Any]`): + Additional parameters from which to initialize the configuration object. + + Returns: + [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters. + """ + + config = cls(**config_dict) + + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + if return_unused_kwargs: + return config, kwargs + else: + return config + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) + + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self, use_diff: bool = True) -> str: + """ + Serializes this instance to a JSON string. + + Args: + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default `PretrainedConfig()` + is serialized to JSON string. + + Returns: + `str`: String containing all the attributes that make up this configuration instance in JSON format. 
+ """ + if use_diff is True: + config_dict = self.to_diff_dict() + else: + config_dict = self.to_dict() + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, + returning all the unused kwargs. + + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + + +@dataclass +class BitsAndBytesConfig(QuantizationConfigMixin): + """ + This is a wrapper class about all possible attributes and features that you can play with a model that has been + loaded using `bitsandbytes`. + + This replaces `load_in_8bit` or `load_in_4bit`therefore both options are mutually exclusive. + + Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`, + then more arguments will be added to this class. + + Args: + load_in_8bit (`bool`, *optional*, defaults to `False`): + This flag is used to enable 8-bit quantization with LLM.int8(). + load_in_4bit (`bool`, *optional*, defaults to `False`): + This flag is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from + `bitsandbytes`. + llm_int8_threshold (`float`, *optional*, defaults to 6.0): + This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix + Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339 Any hidden states value + that is above this threshold will be considered an outlier and the operation on those values will be done + in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but + there are some exceptional systematic outliers that are very differently distributed for large models. + These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of + magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, + but a lower threshold might be needed for more unstable models (small models, fine-tuning). + llm_int8_skip_modules (`List[str]`, *optional*): + An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as + Jukebox that has several heads in different places and not necessarily at the last position. For example + for `CausalLM` models, the last `lm_head` is typically kept in its original `dtype`. + llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`): + This flag is used for advanced use cases and users that are aware of this feature. If you want to split + your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use + this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8 + operations will not be run on CPU. + llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`): + This flag runs LLM.int8() with 16-bit main weights. 
This is useful for fine-tuning as the weights do not + have to be converted back and forth for the backward pass. + bnb_4bit_compute_dtype (`torch.dtype` or str, *optional*, defaults to `torch.float32`): + This sets the computational type which might be different than the input type. For example, inputs might be + fp32, but computation can be set to bf16 for speedups. + bnb_4bit_quant_type (`str`, *optional*, defaults to `"fp4"`): + This sets the quantization data type in the bnb.nn.Linear4Bit layers. Options are FP4 and NF4 data types + which are specified by `fp4` or `nf4`. + bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`): + This flag is used for nested quantization where the quantization constants from the first quantization are + quantized again. + bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`): + This sets the storage type to pack the quanitzed 4-bit prarams. + kwargs (`Dict[str, Any]`, *optional*): + Additional parameters from which to initialize the configuration object. + """ + + _exclude_attributes_at_init = ["_load_in_4bit", "_load_in_8bit", "quant_method"] + + def __init__( + self, + load_in_8bit=False, + load_in_4bit=False, + llm_int8_threshold=6.0, + llm_int8_skip_modules=None, + llm_int8_enable_fp32_cpu_offload=False, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=None, + bnb_4bit_quant_type="fp4", + bnb_4bit_use_double_quant=False, + bnb_4bit_quant_storage=None, + **kwargs, + ): + self.quant_method = QuantizationMethod.BITS_AND_BYTES + + if load_in_4bit and load_in_8bit: + raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time") + + self._load_in_8bit = load_in_8bit + self._load_in_4bit = load_in_4bit + self.llm_int8_threshold = llm_int8_threshold + self.llm_int8_skip_modules = llm_int8_skip_modules + self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload + self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight + self.bnb_4bit_quant_type = bnb_4bit_quant_type + self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant + + if bnb_4bit_compute_dtype is None: + self.bnb_4bit_compute_dtype = torch.float32 + elif isinstance(bnb_4bit_compute_dtype, str): + self.bnb_4bit_compute_dtype = getattr(torch, bnb_4bit_compute_dtype) + elif isinstance(bnb_4bit_compute_dtype, torch.dtype): + self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype + else: + raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype") + + if bnb_4bit_quant_storage is None: + self.bnb_4bit_quant_storage = torch.uint8 + elif isinstance(bnb_4bit_quant_storage, str): + if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]: + raise ValueError( + "`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') " + ) + self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage) + elif isinstance(bnb_4bit_quant_storage, torch.dtype): + self.bnb_4bit_quant_storage = bnb_4bit_quant_storage + else: + raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype") + + if kwargs and not all(k in self._exclude_attributes_at_init for k in kwargs): + logger.warning(f"Unused kwargs: {list(kwargs.keys())}. 
These kwargs are not used in {self.__class__}.") + + self.post_init() + + @property + def load_in_4bit(self): + return self._load_in_4bit + + @load_in_4bit.setter + def load_in_4bit(self, value: bool): + if not isinstance(value, bool): + raise TypeError("load_in_4bit must be a boolean") + + if self.load_in_8bit and value: + raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time") + self._load_in_4bit = value + + @property + def load_in_8bit(self): + return self._load_in_8bit + + @load_in_8bit.setter + def load_in_8bit(self, value: bool): + if not isinstance(value, bool): + raise TypeError("load_in_8bit must be a boolean") + + if self.load_in_4bit and value: + raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time") + self._load_in_8bit = value + + def post_init(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. + """ + if not isinstance(self.load_in_4bit, bool): + raise TypeError("load_in_4bit must be a boolean") + + if not isinstance(self.load_in_8bit, bool): + raise TypeError("load_in_8bit must be a boolean") + + if not isinstance(self.llm_int8_threshold, float): + raise TypeError("llm_int8_threshold must be a float") + + if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list): + raise TypeError("llm_int8_skip_modules must be a list of strings") + if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool): + raise TypeError("llm_int8_enable_fp32_cpu_offload must be a boolean") + + if not isinstance(self.llm_int8_has_fp16_weight, bool): + raise TypeError("llm_int8_has_fp16_weight must be a boolean") + + if self.bnb_4bit_compute_dtype is not None and not isinstance(self.bnb_4bit_compute_dtype, torch.dtype): + raise TypeError("bnb_4bit_compute_dtype must be torch.dtype") + + if not isinstance(self.bnb_4bit_quant_type, str): + raise TypeError("bnb_4bit_quant_type must be a string") + + if not isinstance(self.bnb_4bit_use_double_quant, bool): + raise TypeError("bnb_4bit_use_double_quant must be a boolean") + + if self.load_in_4bit and not version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse( + "0.39.0" + ): + raise ValueError( + "4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version" + ) + + def is_quantizable(self): + r""" + Returns `True` if the model is quantizable, `False` otherwise. + """ + return self.load_in_8bit or self.load_in_4bit + + def quantization_method(self): + r""" + This method returns the quantization method used for the model. If the model is not quantizable, it returns + `None`. + """ + if self.load_in_8bit: + return "llm_int8" + elif self.load_in_4bit and self.bnb_4bit_quant_type == "fp4": + return "fp4" + elif self.load_in_4bit and self.bnb_4bit_quant_type == "nf4": + return "nf4" + else: + return None + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. 
+ """ + output = copy.deepcopy(self.__dict__) + output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1] + output["bnb_4bit_quant_storage"] = str(output["bnb_4bit_quant_storage"]).split(".")[1] + output["load_in_4bit"] = self.load_in_4bit + output["load_in_8bit"] = self.load_in_8bit + + return output + + def __repr__(self): + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" + + def to_diff_dict(self) -> Dict[str, Any]: + """ + Removes all attributes from config which correspond to the default config attributes for better readability and + serializes to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, + """ + config_dict = self.to_dict() + + # get the default config dict + default_config_dict = BitsAndBytesConfig().to_dict() + + serializable_config_dict = {} + + # only serialize values that differ from the default config + for key, value in config_dict.items(): + if value != default_config_dict[key]: + serializable_config_dict[key] = value + + return serializable_config_dict diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 3329919cfb02..6841a34a6489 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -330,6 +330,7 @@ def set_timesteps( # Clipping the minimum of all lambda(t) for numerical stability. # This is critical for cosine (squaredcos_cap_v2) noise schedule. clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + clipped_idx = clipped_idx.item() timesteps = ( np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) .round()[::-1][:-1] diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index c7ea2bcc5b7f..c8f64adf3e8a 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -62,6 +62,7 @@ is_accelerate_available, is_accelerate_version, is_bitsandbytes_available, + is_bitsandbytes_version, is_bs4_available, is_flax_available, is_ftfy_available, @@ -94,7 +95,7 @@ is_xformers_available, requires_backends, ) -from .loading_utils import load_image, load_video +from .loading_utils import get_module_from_name, load_image, load_video from .logging import get_logger from .outputs import BaseOutput from .peft_utils import ( diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index eaab67c93b18..10d0399a6761 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1020,6 +1020,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DiffusersQuantizer(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AmusedScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index d8109eee6d35..9046a4f73533 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ 
b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -272,6 +272,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class CogVideoXFunControlPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class CogVideoXImageToVideoPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index 3ff859f17fe3..448e92509732 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -457,7 +457,7 @@ def _get_checkpoint_shard_files( ignore_patterns = ["*.json", "*.md"] if not local_files_only: # `model_info` call must guarded with the above condition. - model_files_info = model_info(pretrained_model_name_or_path, revision=revision) + model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token) for shard_file in original_shard_filenames: shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings) if not shard_file_present: diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index daecec4aa258..f1323bf00ea4 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -745,6 +745,20 @@ def is_peft_version(operation: str, version: str): return compare_versions(parse(_peft_version), operation, version) +def is_bitsandbytes_version(operation: str, version: str): + """ + Args: + Compares the current bitsandbytes version to a given reference with an operation. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _bitsandbytes_version: + return False + return compare_versions(parse(_bitsandbytes_version), operation, version) + + def is_k_diffusion_version(operation: str, version: str): """ Compares the current k-diffusion version to a given reference with an operation. diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py index b36664cb81ff..bac24fa23e63 100644 --- a/src/diffusers/utils/loading_utils.py +++ b/src/diffusers/utils/loading_utils.py @@ -1,6 +1,6 @@ import os import tempfile -from typing import Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Tuple, Union from urllib.parse import unquote, urlparse import PIL.Image @@ -135,3 +135,16 @@ def load_video( pil_images = convert_method(pil_images) return pil_images + + +# Taken from `transformers`. +def get_module_from_name(module, tensor_name: str) -> Tuple[Any, str]: + if "." 
in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + return module, tensor_name diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py index ca55192ff7ae..dcc78a547a13 100644 --- a/src/diffusers/utils/peft_utils.py +++ b/src/diffusers/utils/peft_utils.py @@ -134,14 +134,14 @@ def unscale_lora_layers(model, weight: Optional[float] = None): """ from peft.tuners.tuners_utils import BaseTunerLayer - if weight == 1.0: + if weight is None or weight == 1.0: return for module in model.modules(): if isinstance(module, BaseTunerLayer): - if weight is not None and weight != 0: + if weight != 0: module.unscale_layer(weight) - elif weight is not None and weight == 0: + else: for adapter_name in module.active_adapters: # if weight == 0 unscale should re-set the scale to the original value. module.set_scale(adapter_name, 1.0) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index a2f283d0c4f5..6361cca663b9 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -1,5 +1,6 @@ import functools import importlib +import importlib.metadata import inspect import io import logging @@ -27,6 +28,8 @@ from .import_utils import ( BACKENDS_MAPPING, + is_accelerate_available, + is_bitsandbytes_available, is_compel_available, is_flax_available, is_note_seq_available, @@ -371,6 +374,20 @@ def require_timm(test_case): return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case) +def require_bitsandbytes(test_case): + """ + Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed. + """ + return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case) + + +def require_accelerate(test_case): + """ + Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed. + """ + return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case) + + def require_peft_version_greater(peft_version): """ Decorator marking a test that requires PEFT backend with a specific version, this would require some specific @@ -408,7 +425,7 @@ def decorator(test_case): def require_accelerate_version_greater(accelerate_version): def decorator(test_case): - correct_accelerate_version = is_peft_available() and version.parse( + correct_accelerate_version = is_accelerate_available() and version.parse( version.parse(importlib.metadata.version("accelerate")).base_version ) > version.parse(accelerate_version) return unittest.skipUnless( @@ -418,6 +435,18 @@ def decorator(test_case): return decorator +def require_bitsandbytes_version_greater(bnb_version): + def decorator(test_case): + correct_bnb_version = is_bitsandbytes_available() and version.parse( + version.parse(importlib.metadata.version("bitsandbytes")).base_version + ) > version.parse(bnb_version) + return unittest.skipUnless( + correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}." 
+ )(test_case) + + return decorator + + def deprecate_after_peft_backend(test_case): """ Decorator marking a test that will be skipped after PEFT backend diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py index 2b9084327289..2be4744c5ac4 100644 --- a/tests/models/transformers/test_models_transformer_sd3.py +++ b/tests/models/transformers/test_models_transformer_sd3.py @@ -73,6 +73,65 @@ def prepare_init_args_and_inputs_for_common(self): "joint_attention_dim": 32, "pooled_projection_dim": 64, "out_channels": 4, + "pos_embed_max_size": 96, + "dual_attention_layers": (), + "qk_norm": None, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply") + def test_set_attn_processor_for_determinism(self): + pass + + +class SD35TransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = SD3Transformer2DModel + main_input_name = "hidden_states" + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + height = width = embedding_dim = 32 + pooled_embedding_dim = embedding_dim * 2 + sequence_length = 154 + + hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_prompt_embeds, + "timestep": timestep, + } + + @property + def input_shape(self): + return (4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "sample_size": 32, + "patch_size": 1, + "in_channels": 4, + "num_layers": 2, + "attention_head_dim": 8, + "num_attention_heads": 4, + "caption_projection_dim": 32, + "joint_attention_dim": 32, + "pooled_projection_dim": 64, + "out_channels": 4, + "pos_embed_max_size": 96, + "dual_attention_layers": (0,), + "qk_norm": "rms_norm", } inputs_dict = self.dummy_input return init_dict, inputs_dict diff --git a/tests/pipelines/cogvideo/test_cogvideox_fun_control.py b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py new file mode 100644 index 000000000000..2a51fc65798c --- /dev/null +++ b/tests/pipelines/cogvideo/test_cogvideox_fun_control.py @@ -0,0 +1,324 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLCogVideoX, CogVideoXFunControlPipeline, CogVideoXTransformer3DModel, DDIMScheduler +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import ( + PipelineTesterMixin, + check_qkv_fusion_matches_attn_procs_length, + check_qkv_fusion_processors_exist, + to_np, +) + + +enable_full_determinism() + + +class CogVideoXFunControlPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = CogVideoXFunControlPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"control_video"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = CogVideoXTransformer3DModel( + # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings + # But, since we are using tiny-random-t5 here, we need the internal dim of CogVideoXTransformer3DModel + # to be 32. The internal dim is product of num_attention_heads and attention_head_dim + num_attention_heads=4, + attention_head_dim=8, + in_channels=8, + out_channels=4, + time_embed_dim=2, + text_embed_dim=32, # Must match with tiny-random-t5 + num_layers=1, + sample_width=2, # latent width: 2 -> final width: 16 + sample_height=2, # latent height: 2 -> final height: 16 + sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9 + patch_size=2, + temporal_compression_ratio=4, + max_text_seq_length=16, + ) + + torch.manual_seed(0) + vae = AutoencoderKLCogVideoX( + in_channels=3, + out_channels=3, + down_block_types=( + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + "CogVideoXDownBlock3D", + ), + up_block_types=( + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + "CogVideoXUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + latent_channels=4, + layers_per_block=1, + norm_num_groups=2, + temporal_compression_ratio=4, + ) + + torch.manual_seed(0) + scheduler = DDIMScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed: int = 0, num_frames: int = 8): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + # Cannot reduce because convolution kernel becomes bigger than sample + height = 16 + width = 16 + + control_video = [Image.new("RGB", (width, height))] * num_frames + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "", + "control_video": control_video, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": height, + 
"width": width, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (8, 3, 16, 16)) + expected_video = torch.randn(8, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = 
pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.5): + # NOTE(aryan): This requires a higher expected_max_diff than other CogVideoX pipelines + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_overlap_factor_height=1 / 12, + tile_overlap_factor_width=1 / 12, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + def test_fused_qkv_projections(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames # [B, F, C, H, W] + original_image_slice = frames[0, -2:, -1, -3:, -3:] + + pipe.fuse_qkv_projections() + assert check_qkv_fusion_processors_exist( + pipe.transformer + ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused." + assert check_qkv_fusion_matches_attn_procs_length( + pipe.transformer, pipe.transformer.original_attn_processors + ), "Something wrong with the attention processors concerning the fused QKV projections." + + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames + image_slice_fused = frames[0, -2:, -1, -3:, -3:] + + pipe.transformer.unfuse_qkv_projections() + inputs = self.get_dummy_inputs(device) + frames = pipe(**inputs).frames + image_slice_disabled = frames[0, -2:, -1, -3:, -3:] + + assert np.allclose( + original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3 + ), "Fusion of QKV projections shouldn't affect the outputs." + assert np.allclose( + image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3 + ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled." + assert np.allclose( + original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 + ), "Original outputs should match when fused QKV projections are disabled." 
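For context on the fast tests above, here is a minimal usage sketch of the new `CogVideoXFunControlPipeline` they exercise. Only the parameter names (`prompt`, `negative_prompt`, `control_video`, `num_inference_steps`, `guidance_scale`) and the `.frames` output come from the test code above; the checkpoint id, dtype, and resolution are placeholder assumptions, not something this PR asserts.

```python
# Hedged usage sketch for CogVideoXFunControlPipeline; the checkpoint id below is a placeholder.
import torch

from diffusers import CogVideoXFunControlPipeline
from diffusers.utils import export_to_video, load_video

# NOTE: "your-org/cogvideox-fun-control" is hypothetical; substitute a real Fun-Control checkpoint.
pipe = CogVideoXFunControlPipeline.from_pretrained(
    "your-org/cogvideox-fun-control", torch_dtype=torch.bfloat16
).to("cuda")

# The control video is passed as a list of PIL images, mirroring `get_dummy_inputs` in the test above.
control_video = load_video("path/to/control.mp4")

video = pipe(
    prompt="a robot dancing in the rain",
    negative_prompt="blurry, low quality",
    control_video=control_video,
    num_inference_steps=50,
    guidance_scale=6.0,
    generator=torch.Generator(device="cuda").manual_seed(0),
).frames[0]

export_to_video(video, "output.mp4", fps=8)
```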
diff --git a/tests/pipelines/cogvideo/test_cogvideox_image2video.py b/tests/pipelines/cogvideo/test_cogvideox_image2video.py index ec9a5fdd153e..f7e1fe7fd6c7 100644 --- a/tests/pipelines/cogvideo/test_cogvideox_image2video.py +++ b/tests/pipelines/cogvideo/test_cogvideox_image2video.py @@ -43,7 +43,7 @@ enable_full_determinism() -class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = CogVideoXImageToVideoPipeline params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"}) @@ -343,7 +343,6 @@ def test_fused_qkv_projections(self): ), "Original outputs should match when fused QKV projections are disabled." -@unittest.skip("The model 'THUDM/CogVideoX-5b-I2V' is not public yet.") @slow @require_torch_gpu class CogVideoXImageToVideoPipelineIntegrationTests(unittest.TestCase): diff --git a/tests/pipelines/test_pipeline_utils.py b/tests/pipelines/test_pipeline_utils.py index 697244dcb105..bb3bdc273cc4 100644 --- a/tests/pipelines/test_pipeline_utils.py +++ b/tests/pipelines/test_pipeline_utils.py @@ -18,7 +18,7 @@ StableDiffusionPipeline, UNet2DConditionModel, ) -from diffusers.pipelines.pipeline_utils import is_safetensors_compatible +from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible from diffusers.utils.testing_utils import torch_device @@ -197,6 +197,18 @@ def test_diffusers_is_compatible_only_variants(self): ] self.assertTrue(is_safetensors_compatible(filenames)) + def test_diffusers_is_compatible_no_components(self): + filenames = [ + "diffusion_pytorch_model.bin", + ] + self.assertFalse(is_safetensors_compatible(filenames)) + + def test_diffusers_is_compatible_no_components_only_variants(self): + filenames = [ + "diffusion_pytorch_model.fp16.bin", + ] + self.assertFalse(is_safetensors_compatible(filenames)) + class ProgressBarTests(unittest.TestCase): def get_dummy_components_image_generation(self): diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 8b087db6726e..43b01c40f5bb 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -947,6 +947,27 @@ def test_text_inversion_multi_tokens(self): emb1[num_tokens + 1].sum().item() == emb2[num_tokens + 1].sum().item() == emb3[num_tokens + 1].sum().item() ) + def test_textual_inversion_unload(self): + pipe1 = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None + ) + pipe1 = pipe1.to(torch_device) + orig_tokenizer_size = len(pipe1.tokenizer) + orig_emb_size = len(pipe1.text_encoder.get_input_embeddings().weight) + + token = "<*>" + ten = torch.ones((32,)) + pipe1.load_textual_inversion(ten, token=token) + pipe1.unload_textual_inversion() + pipe1.load_textual_inversion(ten, token=token) + pipe1.unload_textual_inversion() + + final_tokenizer_size = len(pipe1.tokenizer) + final_emb_size = len(pipe1.text_encoder.get_input_embeddings().weight) + # both should be restored to original size + assert final_tokenizer_size == orig_tokenizer_size + assert final_emb_size == orig_emb_size + def test_download_ignore_files(self): # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4 with tempfile.TemporaryDirectory() as tmpdirname: diff --git a/tests/quantization/bnb/README.md b/tests/quantization/bnb/README.md new file mode 
100644 index 000000000000..f1585581597d --- /dev/null +++ b/tests/quantization/bnb/README.md @@ -0,0 +1,44 @@ +The tests here are adapted from [`transformers` tests](https://github.com/huggingface/transformers/tree/409fcfdfccde77a14b7cc36972b774cabc371ae1/tests/quantization/bnb). + +They were conducted on the `audace` machine, using a single RTX 4090. Below is `nvidia-smi`: + +```bash ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 550.90.07 Driver Version: 550.90.07 CUDA Version: 12.4 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 4090 Off | 00000000:01:00.0 Off | Off | +| 30% 55C P0 61W / 450W | 1MiB / 24564MiB | 2% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA GeForce RTX 4090 Off | 00000000:13:00.0 Off | Off | +| 30% 51C P0 60W / 450W | 1MiB / 24564MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ +``` + +`diffusers-cli`: + +```bash +- 🤗 Diffusers version: 0.31.0.dev0 +- Platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.35 +- Running on Google Colab?: No +- Python version: 3.10.12 +- PyTorch version (GPU?): 2.5.0.dev20240818+cu124 (True) +- Flax version (CPU?/GPU?/TPU?): not installed (NA) +- Jax version: not installed +- JaxLib version: not installed +- Huggingface_hub version: 0.24.5 +- Transformers version: 4.44.2 +- Accelerate version: 0.34.0.dev0 +- PEFT version: 0.12.0 +- Bitsandbytes version: 0.43.3 +- Safetensors version: 0.4.4 +- xFormers version: not installed +- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB +NVIDIA GeForce RTX 4090, 24564 MiB +- Using GPU in script?: Yes +``` \ No newline at end of file diff --git a/tests/quantization/bnb/__init__.py b/tests/quantization/bnb/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py new file mode 100644 index 000000000000..7b553434fbe9 --- /dev/null +++ b/tests/quantization/bnb/test_4bit.py @@ -0,0 +1,627 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a clone of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import gc +import os +import tempfile +import unittest + +import numpy as np +import safetensors.torch + +from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel +from diffusers.utils import logging +from diffusers.utils.testing_utils import ( + CaptureLogger, + is_bitsandbytes_available, + is_torch_available, + is_transformers_available, + load_pt, + numpy_cosine_similarity_distance, + require_accelerate, + require_bitsandbytes_version_greater, + require_torch, + require_torch_gpu, + require_transformers_version_greater, + slow, + torch_device, +) + + +def get_some_linear_layer(model): + if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]: + return model.transformer_blocks[0].attn.to_q + else: + return NotImplementedError("Don't know what layer to retrieve here.") + + +if is_transformers_available(): + from transformers import T5EncoderModel + +if is_torch_available(): + import torch + import torch.nn as nn + + class LoRALayer(nn.Module): + """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only + + Taken from + https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77 + """ + + def __init__(self, module: nn.Module, rank: int): + super().__init__() + self.module = module + self.adapter = nn.Sequential( + nn.Linear(module.in_features, rank, bias=False), + nn.Linear(rank, module.out_features, bias=False), + ) + small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5 + nn.init.normal_(self.adapter[0].weight, std=small_std) + nn.init.zeros_(self.adapter[1].weight) + self.adapter.to(module.weight.device) + + def forward(self, input, *args, **kwargs): + return self.module(input, *args, **kwargs) + self.adapter(input) + + +if is_bitsandbytes_available(): + import bitsandbytes as bnb + + +@require_bitsandbytes_version_greater("0.43.2") +@require_accelerate +@require_torch +@require_torch_gpu +@slow +class Base4bitTests(unittest.TestCase): + # We need to test on relatively large models (aka >1b parameters otherwise the quantiztion may not work as expected) + # Therefore here we use only SD3 to test our module + model_name = "stabilityai/stable-diffusion-3-medium-diffusers" + + # This was obtained on audace so the number might slightly change + expected_rel_difference = 3.69 + + prompt = "a beautiful sunset amidst the mountains." 
+ num_inference_steps = 10 + seed = 0 + + def get_dummy_inputs(self): + prompt_embeds = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt" + ) + pooled_prompt_embeds = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt" + ) + latent_model_input = load_pt( + "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt" + ) + + input_dict_for_transformer = { + "hidden_states": latent_model_input, + "encoder_hidden_states": prompt_embeds, + "pooled_projections": pooled_prompt_embeds, + "timestep": torch.Tensor([1.0]), + "return_dict": False, + } + return input_dict_for_transformer + + +class BnB4BitBasicTests(Base4bitTests): + def setUp(self): + gc.collect() + torch.cuda.empty_cache() + + # Models + self.model_fp16 = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", torch_dtype=torch.float16 + ) + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + ) + self.model_4bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config + ) + + def tearDown(self): + del self.model_fp16 + del self.model_4bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quantization_num_parameters(self): + r""" + Test if the number of returned parameters is correct + """ + num_params_4bit = self.model_4bit.num_parameters() + num_params_fp16 = self.model_fp16.num_parameters() + + self.assertEqual(num_params_4bit, num_params_fp16) + + def test_quantization_config_json_serialization(self): + r""" + A simple test to check if the quantization config is correctly serialized and deserialized + """ + config = self.model_4bit.config + + self.assertTrue("quantization_config" in config) + + _ = config["quantization_config"].to_dict() + _ = config["quantization_config"].to_diff_dict() + + _ = config["quantization_config"].to_json_string() + + def test_memory_footprint(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + mem_fp16 = self.model_fp16.get_memory_footprint() + mem_4bit = self.model_4bit.get_memory_footprint() + + self.assertAlmostEqual(mem_fp16 / mem_4bit, self.expected_rel_difference, delta=1e-2) + linear = get_some_linear_layer(self.model_4bit) + self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) + + def test_original_dtype(self): + r""" + A simple test to check if the model succesfully stores the original dtype + """ + self.assertTrue("_pre_quantization_dtype" in self.model_4bit.config) + self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config) + self.assertTrue(self.model_4bit.config["_pre_quantization_dtype"] == torch.float16) + + def test_keep_modules_in_fp32(self): + r""" + A simple tests to check if the modules under `_keep_in_fp32_modules` are kept in fp32. + Also ensures if inference works. 
+ """ + fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules + SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"] + + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + ) + model = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config + ) + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if name in model._keep_in_fp32_modules: + self.assertTrue(module.weight.dtype == torch.float32) + else: + # 4-bit parameters are packed in uint8 variables + self.assertTrue(module.weight.dtype == torch.uint8) + + # test if inference works. + with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16): + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + _ = model(**model_inputs) + + SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules + + def test_linear_are_4bit(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + self.model_fp16.get_memory_footprint() + self.model_4bit.get_memory_footprint() + + for name, module in self.model_4bit.named_modules(): + if isinstance(module, torch.nn.Linear): + if name not in ["proj_out"]: + # 4-bit parameters are packed in uint8 variables + self.assertTrue(module.weight.dtype == torch.uint8) + + def test_config_from_pretrained(self): + transformer_4bit = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer" + ) + linear = get_some_linear_layer(transformer_4bit) + self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) + self.assertTrue(hasattr(linear.weight, "quant_state")) + self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState) + + def test_device_assignment(self): + mem_before = self.model_4bit.get_memory_footprint() + + # Move to CPU + self.model_4bit.to("cpu") + self.assertEqual(self.model_4bit.device.type, "cpu") + self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) + + # Move back to CUDA device + for device in [0, "cuda", "cuda:0", "call()"]: + if device == "call()": + self.model_4bit.cuda(0) + else: + self.model_4bit.to(device) + self.assertEqual(self.model_4bit.device, torch.device(0)) + self.assertAlmostEqual(self.model_4bit.get_memory_footprint(), mem_before) + self.model_4bit.to("cpu") + + def test_device_and_dtype_assignment(self): + r""" + Test whether trying to cast (or assigning a device to) a model after converting it in 4-bit will throw an error. + Checks also if other models are casted correctly. Device placement, however, is supported. 
+ """ + with self.assertRaises(ValueError): + # Tries with a `dtype` + self.model_4bit.to(torch.float16) + + with self.assertRaises(ValueError): + # Tries with a `device` and `dtype` + self.model_4bit.to(device="cuda:0", dtype=torch.float16) + + with self.assertRaises(ValueError): + # Tries with a cast + self.model_4bit.float() + + with self.assertRaises(ValueError): + # Tries with a cast + self.model_4bit.half() + + # This should work + self.model_4bit.to("cuda") + + # Test if we did not break anything + self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device) + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(dtype=torch.float32, device=torch_device) + for k, v in input_dict_for_transformer.items() + if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + with torch.no_grad(): + _ = self.model_fp16(**model_inputs) + + # Check this does not throw an error + _ = self.model_fp16.to("cpu") + + # Check this does not throw an error + _ = self.model_fp16.half() + + # Check this does not throw an error + _ = self.model_fp16.float() + + # Check that this does not throw an error + _ = self.model_fp16.cuda() + + def test_bnb_4bit_wrong_config(self): + r""" + Test whether creating a bnb config with unsupported values leads to errors. + """ + with self.assertRaises(ValueError): + _ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add") + + def test_bnb_4bit_errors_loading_incorrect_state_dict(self): + r""" + Test if loading with an incorrect state dict raises an error. + """ + with tempfile.TemporaryDirectory() as tmpdirname: + nf4_config = BitsAndBytesConfig(load_in_4bit=True) + model_4bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config + ) + model_4bit.save_pretrained(tmpdirname) + del model_4bit + + with self.assertRaises(ValueError) as err_context: + state_dict = safetensors.torch.load_file( + os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors") + ) + + # corrupt the state dict + key_to_target = "context_embedder.weight" # can be other keys too. + compatible_param = state_dict[key_to_target] + corrupted_param = torch.randn(compatible_param.shape[0] - 1, 1) + state_dict[key_to_target] = bnb.nn.Params4bit(corrupted_param, requires_grad=False) + safetensors.torch.save_file( + state_dict, os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors") + ) + + _ = SD3Transformer2DModel.from_pretrained(tmpdirname) + + assert key_to_target in str(err_context.exception) + + +class BnB4BitTrainingTests(Base4bitTests): + def setUp(self): + gc.collect() + torch.cuda.empty_cache() + + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + ) + self.model_4bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config + ) + + def test_training(self): + # Step 1: freeze all parameters + for param in self.model_4bit.parameters(): + param.requires_grad = False # freeze the model - train adapters later + if param.ndim == 1: + # cast the small parameters (e.g. 
layernorm) to fp32 for stability + param.data = param.data.to(torch.float32) + + # Step 2: add adapters + for _, module in self.model_4bit.named_modules(): + if "Attention" in repr(type(module)): + module.to_k = LoRALayer(module.to_k, rank=4) + module.to_q = LoRALayer(module.to_q, rank=4) + module.to_v = LoRALayer(module.to_v, rank=4) + + # Step 3: dummy batch + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + + # Step 4: Check if the gradient is not None + with torch.amp.autocast("cuda", dtype=torch.float16): + out = self.model_4bit(**model_inputs)[0] + out.norm().backward() + + for module in self.model_4bit.modules(): + if isinstance(module, LoRALayer): + self.assertTrue(module.adapter[1].weight.grad is not None) + self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) + + +@require_transformers_version_greater("4.44.0") +class SlowBnb4BitTests(Base4bitTests): + def setUp(self) -> None: + gc.collect() + torch.cuda.empty_cache() + + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + ) + model_4bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config + ) + self.pipeline_4bit = DiffusionPipeline.from_pretrained( + self.model_name, transformer=model_4bit, torch_dtype=torch.float16 + ) + self.pipeline_4bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_4bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quality(self): + output = self.pipeline_4bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.1123, 0.1296, 0.1609, 0.1042, 0.1230, 0.1274, 0.0928, 0.1165, 0.1216]) + + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + print(f"{max_diff=}") + self.assertTrue(max_diff < 1e-2) + + def test_generate_quality_dequantize(self): + r""" + Test that loading the model and unquantize it produce correct results. + """ + self.pipeline_4bit.transformer.dequantize() + output = self.pipeline_4bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.1216, 0.1387, 0.1584, 0.1152, 0.1318, 0.1282, 0.1062, 0.1226, 0.1228]) + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-3) + + # Since we offloaded the `pipeline_4bit.transformer` to CPU (result of `enable_model_cpu_offload()), check + # the following. 
+ self.assertTrue(self.pipeline_4bit.transformer.device.type == "cpu") + # calling it again shouldn't be a problem + _ = self.pipeline_4bit( + prompt=self.prompt, + num_inference_steps=2, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + def test_moving_to_cpu_throws_warning(self): + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + ) + model_4bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=nf4_config + ) + + logger = logging.get_logger("diffusers.pipelines.pipeline_utils") + logger.setLevel(30) + with CaptureLogger(logger) as cap_logger: + # Because `model.dtype` will return torch.float16 as SD3 transformer has + # a conv layer as the first layer. + _ = DiffusionPipeline.from_pretrained( + self.model_name, transformer=model_4bit, torch_dtype=torch.float16 + ).to("cpu") + + assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out + + +@require_transformers_version_greater("4.44.0") +class SlowBnb4BitFluxTests(Base4bitTests): + def setUp(self) -> None: + gc.collect() + torch.cuda.empty_cache() + + model_id = "hf-internal-testing/flux.1-dev-nf4-pkg" + t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer") + self.pipeline_4bit = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + text_encoder_2=t5_4bit, + transformer=transformer_4bit, + torch_dtype=torch.float16, + ) + self.pipeline_4bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_4bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quality(self): + # keep the resolution and max tokens to a lower number for faster execution. + output = self.pipeline_4bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + height=256, + width=256, + max_sequence_length=64, + output_type="np", + ).images + + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.0583, 0.0586, 0.0632, 0.0815, 0.0813, 0.0947, 0.1040, 0.1145, 0.1265]) + + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-3) + + +@slow +class BaseBnb4BitSerializationTests(Base4bitTests): + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def test_serialization(self, quant_type="nf4", double_quant=True, safe_serialization=True): + r""" + Test whether it is possible to serialize a model in 4-bit. Uses most typical params as default. + See ExtendedSerializationTest class for more params combinations. 
+ """ + + self.quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type=quant_type, + bnb_4bit_use_double_quant=double_quant, + bnb_4bit_compute_dtype=torch.bfloat16, + ) + model_0 = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=self.quantization_config + ) + self.assertTrue("_pre_quantization_dtype" in model_0.config) + with tempfile.TemporaryDirectory() as tmpdirname: + model_0.save_pretrained(tmpdirname, safe_serialization=safe_serialization) + + config = SD3Transformer2DModel.load_config(tmpdirname) + self.assertTrue("quantization_config" in config) + self.assertTrue("_pre_quantization_dtype" not in config) + + model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname) + + # checking quantized linear module weight + linear = get_some_linear_layer(model_1) + self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit) + self.assertTrue(hasattr(linear.weight, "quant_state")) + self.assertTrue(linear.weight.quant_state.__class__ == bnb.functional.QuantState) + + # checking memory footpring + self.assertAlmostEqual(model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2) + + # Matching all parameters and their quant_state items: + d0 = dict(model_0.named_parameters()) + d1 = dict(model_1.named_parameters()) + self.assertTrue(d0.keys() == d1.keys()) + + for k in d0.keys(): + self.assertTrue(d0[k].shape == d1[k].shape) + self.assertTrue(d0[k].device.type == d1[k].device.type) + self.assertTrue(d0[k].device == d1[k].device) + self.assertTrue(d0[k].dtype == d1[k].dtype) + self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device))) + + if isinstance(d0[k], bnb.nn.modules.Params4bit): + for v0, v1 in zip( + d0[k].quant_state.as_dict().values(), + d1[k].quant_state.as_dict().values(), + ): + if isinstance(v0, torch.Tensor): + self.assertTrue(torch.equal(v0, v1.to(v0.device))) + else: + self.assertTrue(v0 == v1) + + # comparing forward() outputs + dummy_inputs = self.get_dummy_inputs() + inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)} + inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs}) + out_0 = model_0(**inputs)[0] + out_1 = model_1(**inputs)[0] + self.assertTrue(torch.equal(out_0, out_1)) + + +class ExtendedSerializationTest(BaseBnb4BitSerializationTests): + """ + tests more combinations of parameters + """ + + def test_nf4_single_unsafe(self): + self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=False) + + def test_nf4_single_safe(self): + self.test_serialization(quant_type="nf4", double_quant=False, safe_serialization=True) + + def test_nf4_double_unsafe(self): + self.test_serialization(quant_type="nf4", double_quant=True, safe_serialization=False) + + # nf4 double safetensors quantization is tested in test_serialization() method from the parent class + + def test_fp4_single_unsafe(self): + self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=False) + + def test_fp4_single_safe(self): + self.test_serialization(quant_type="fp4", double_quant=False, safe_serialization=True) + + def test_fp4_double_unsafe(self): + self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=False) + + def test_fp4_double_safe(self): + self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py new file mode 100644 index 
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
new file mode 100644
index 000000000000..ba2402461c87
--- /dev/null
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -0,0 +1,552 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import tempfile
+import unittest
+
+import numpy as np
+
+from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
+from diffusers.utils.testing_utils import (
+    CaptureLogger,
+    is_bitsandbytes_available,
+    is_torch_available,
+    is_transformers_available,
+    load_pt,
+    numpy_cosine_similarity_distance,
+    require_accelerate,
+    require_bitsandbytes_version_greater,
+    require_torch,
+    require_torch_gpu,
+    require_transformers_version_greater,
+    slow,
+    torch_device,
+)
+
+
+def get_some_linear_layer(model):
+    if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
+        return model.transformer_blocks[0].attn.to_q
+    else:
+        raise NotImplementedError("Don't know what layer to retrieve here.")
+
+
+if is_transformers_available():
+    from transformers import T5EncoderModel
+
+if is_torch_available():
+    import torch
+    import torch.nn as nn
+
+    class LoRALayer(nn.Module):
+        """Wraps a linear layer with a LoRA-like adapter - used for testing purposes only
+
+        Taken from
+        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77
+        """
+
+        def __init__(self, module: nn.Module, rank: int):
+            super().__init__()
+            self.module = module
+            self.adapter = nn.Sequential(
+                nn.Linear(module.in_features, rank, bias=False),
+                nn.Linear(rank, module.out_features, bias=False),
+            )
+            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
+            nn.init.normal_(self.adapter[0].weight, std=small_std)
+            nn.init.zeros_(self.adapter[1].weight)
+            self.adapter.to(module.weight.device)
+
+        def forward(self, input, *args, **kwargs):
+            return self.module(input, *args, **kwargs) + self.adapter(input)
+
+
+if is_bitsandbytes_available():
+    import bitsandbytes as bnb
+
+
+@require_bitsandbytes_version_greater("0.43.2")
+@require_accelerate
+@require_torch
+@require_torch_gpu
+@slow
+class Base8bitTests(unittest.TestCase):
+    # We need to test on relatively large models (aka >1b parameters, otherwise the quantization may not work as expected)
+    # Therefore here we use only SD3 to test our module
+    model_name = "stabilityai/stable-diffusion-3-medium-diffusers"
+
+    # This was obtained on audace so the number might slightly change
+    expected_rel_difference = 1.94
+
+    prompt = "a beautiful sunset amidst the mountains."
+    num_inference_steps = 10
+    seed = 0
+
+    def get_dummy_inputs(self):
+        prompt_embeds = load_pt(
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt"
+        )
+        pooled_prompt_embeds = load_pt(
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/pooled_prompt_embeds.pt"
+        )
+        latent_model_input = load_pt(
+            "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/latent_model_input.pt"
+        )
+
+        input_dict_for_transformer = {
+            "hidden_states": latent_model_input,
+            "encoder_hidden_states": prompt_embeds,
+            "pooled_projections": pooled_prompt_embeds,
+            "timestep": torch.Tensor([1.0]),
+            "return_dict": False,
+        }
+        return input_dict_for_transformer
+
+
+class BnB8bitBasicTests(Base8bitTests):
+    def setUp(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # Models
+        self.model_fp16 = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", torch_dtype=torch.float16
+        )
+        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+        self.model_8bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+        )
+
+    def tearDown(self):
+        del self.model_fp16
+        del self.model_8bit
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_quantization_num_parameters(self):
+        r"""
+        Test if the number of returned parameters is correct
+        """
+        num_params_8bit = self.model_8bit.num_parameters()
+        num_params_fp16 = self.model_fp16.num_parameters()
+
+        self.assertEqual(num_params_8bit, num_params_fp16)
+
+    def test_quantization_config_json_serialization(self):
+        r"""
+        A simple test to check if the quantization config is correctly serialized and deserialized
+        """
+        config = self.model_8bit.config
+
+        self.assertTrue("quantization_config" in config)
+
+        _ = config["quantization_config"].to_dict()
+        _ = config["quantization_config"].to_diff_dict()
+
+        _ = config["quantization_config"].to_json_string()
+
+    def test_memory_footprint(self):
+        r"""
+        A simple test to check if the model conversion has been done correctly by checking on the
+        memory footprint of the converted model and the class type of the linear layers of the converted models
+        """
+        mem_fp16 = self.model_fp16.get_memory_footprint()
+        mem_8bit = self.model_8bit.get_memory_footprint()
+
+        self.assertAlmostEqual(mem_fp16 / mem_8bit, self.expected_rel_difference, delta=1e-2)
+        linear = get_some_linear_layer(self.model_8bit)
+        self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
+
+    def test_original_dtype(self):
+        r"""
+        A simple test to check if the model successfully stores the original dtype
+        """
+        self.assertTrue("_pre_quantization_dtype" in self.model_8bit.config)
+        self.assertFalse("_pre_quantization_dtype" in self.model_fp16.config)
+        self.assertTrue(self.model_8bit.config["_pre_quantization_dtype"] == torch.float16)
+
+    def test_keep_modules_in_fp32(self):
+        r"""
+        A simple test to check that the modules under `_keep_in_fp32_modules` are kept in fp32.
+        Also ensures that inference works.
+ """ + fp32_modules = SD3Transformer2DModel._keep_in_fp32_modules + SD3Transformer2DModel._keep_in_fp32_modules = ["proj_out"] + + mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) + model = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=mixed_int8_config + ) + + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + if name in model._keep_in_fp32_modules: + self.assertTrue(module.weight.dtype == torch.float32) + else: + # 8-bit parameters are packed in int8 variables + self.assertTrue(module.weight.dtype == torch.int8) + + # test if inference works. + with torch.no_grad() and torch.amp.autocast("cuda", dtype=torch.float16): + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + _ = model(**model_inputs) + + SD3Transformer2DModel._keep_in_fp32_modules = fp32_modules + + def test_linear_are_8bit(self): + r""" + A simple test to check if the model conversion has been done correctly by checking on the + memory footprint of the converted model and the class type of the linear layers of the converted models + """ + self.model_fp16.get_memory_footprint() + self.model_8bit.get_memory_footprint() + + for name, module in self.model_8bit.named_modules(): + if isinstance(module, torch.nn.Linear): + if name not in ["proj_out"]: + # 8-bit parameters are packed in int8 variables + self.assertTrue(module.weight.dtype == torch.int8) + + def test_llm_skip(self): + r""" + A simple test to check if `llm_int8_skip_modules` works as expected + """ + config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["proj_out"]) + model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=config + ) + linear = get_some_linear_layer(model_8bit) + self.assertTrue(linear.weight.dtype == torch.int8) + self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt)) + + self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear)) + self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8) + + def test_config_from_pretrained(self): + transformer_8bit = FluxTransformer2DModel.from_pretrained( + "hf-internal-testing/flux.1-dev-int8-pkg", subfolder="transformer" + ) + linear = get_some_linear_layer(transformer_8bit) + self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + self.assertTrue(hasattr(linear.weight, "SCB")) + + def test_device_and_dtype_assignment(self): + r""" + Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error. + Checks also if other models are casted correctly. 
+ """ + with self.assertRaises(ValueError): + # Tries with `str` + self.model_8bit.to("cpu") + + with self.assertRaises(ValueError): + # Tries with a `dtype`` + self.model_8bit.to(torch.float16) + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_8bit.to(torch.device("cuda:0")) + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_8bit.float() + + with self.assertRaises(ValueError): + # Tries with a `device` + self.model_8bit.half() + + # Test if we did not break anything + self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device) + input_dict_for_transformer = self.get_dummy_inputs() + model_inputs = { + k: v.to(dtype=torch.float32, device=torch_device) + for k, v in input_dict_for_transformer.items() + if not isinstance(v, bool) + } + model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) + with torch.no_grad(): + _ = self.model_fp16(**model_inputs) + + # Check this does not throw an error + _ = self.model_fp16.to("cpu") + + # Check this does not throw an error + _ = self.model_fp16.half() + + # Check this does not throw an error + _ = self.model_fp16.float() + + # Check that this does not throw an error + _ = self.model_fp16.cuda() + + +class BnB8bitTrainingTests(Base8bitTests): + def setUp(self): + gc.collect() + torch.cuda.empty_cache() + + mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True) + self.model_8bit = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=mixed_int8_config + ) + + def test_training(self): + # Step 1: freeze all parameters + for param in self.model_8bit.parameters(): + param.requires_grad = False # freeze the model - train adapters later + if param.ndim == 1: + # cast the small parameters (e.g. 
+                param.data = param.data.to(torch.float32)
+
+        # Step 2: add adapters
+        for _, module in self.model_8bit.named_modules():
+            if "Attention" in repr(type(module)):
+                module.to_k = LoRALayer(module.to_k, rank=4)
+                module.to_q = LoRALayer(module.to_q, rank=4)
+                module.to_v = LoRALayer(module.to_v, rank=4)
+
+        # Step 3: dummy batch
+        input_dict_for_transformer = self.get_dummy_inputs()
+        model_inputs = {
+            k: v.to(device=torch_device) for k, v in input_dict_for_transformer.items() if not isinstance(v, bool)
+        }
+        model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
+
+        # Step 4: Check if the gradient is not None
+        with torch.amp.autocast("cuda", dtype=torch.float16):
+            out = self.model_8bit(**model_inputs)[0]
+            out.norm().backward()
+
+        for module in self.model_8bit.modules():
+            if isinstance(module, LoRALayer):
+                self.assertTrue(module.adapter[1].weight.grad is not None)
+                self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
+
+
+@require_transformers_version_greater("4.44.0")
+class SlowBnb8bitTests(Base8bitTests):
+    def setUp(self) -> None:
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
+        model_8bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
+        )
+        self.pipeline_8bit = DiffusionPipeline.from_pretrained(
+            self.model_name, transformer=model_8bit, torch_dtype=torch.float16
+        )
+        self.pipeline_8bit.enable_model_cpu_offload()
+
+    def tearDown(self):
+        del self.pipeline_8bit
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_quality(self):
+        output = self.pipeline_8bit(
+            prompt=self.prompt,
+            num_inference_steps=self.num_inference_steps,
+            generator=torch.manual_seed(self.seed),
+            output_type="np",
+        ).images
+        out_slice = output[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.0149, 0.0322, 0.0073, 0.0134, 0.0332, 0.011, 0.002, 0.0232, 0.0193])
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
+        self.assertTrue(max_diff < 1e-2)
+
+    def test_model_cpu_offload_raises_warning(self):
+        model_8bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
+        )
+        pipeline_8bit = DiffusionPipeline.from_pretrained(
+            self.model_name, transformer=model_8bit, torch_dtype=torch.float16
+        )
+        logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
+        logger.setLevel(30)
+
+        with CaptureLogger(logger) as cap_logger:
+            pipeline_8bit.enable_model_cpu_offload()
+
+        assert "has been loaded in `bitsandbytes` 8bit" in cap_logger.out
+
+    def test_moving_to_cpu_throws_warning(self):
+        model_8bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name, subfolder="transformer", quantization_config=BitsAndBytesConfig(load_in_8bit=True)
+        )
+        logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
+        logger.setLevel(30)
+
+        with CaptureLogger(logger) as cap_logger:
+            # Because `model.dtype` will return torch.float16 as SD3 transformer has
+            # a conv layer as the first layer.
+            _ = DiffusionPipeline.from_pretrained(
+                self.model_name, transformer=model_8bit, torch_dtype=torch.float16
+            ).to("cpu")
+
+        assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out
+
+    def test_generate_quality_dequantize(self):
+        r"""
+        Test that loading the model and dequantizing it produces correct results.
+ """ + self.pipeline_8bit.transformer.dequantize() + output = self.pipeline_8bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.0266, 0.0264, 0.0271, 0.0110, 0.0310, 0.0098, 0.0078, 0.0256, 0.0208]) + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-2) + + # 8bit models cannot be offloaded to CPU. + self.assertTrue(self.pipeline_8bit.transformer.device.type == "cuda") + # calling it again shouldn't be a problem + _ = self.pipeline_8bit( + prompt=self.prompt, + num_inference_steps=2, + generator=torch.manual_seed(self.seed), + output_type="np", + ).images + + +@require_transformers_version_greater("4.44.0") +class SlowBnb8bitFluxTests(Base8bitTests): + def setUp(self) -> None: + gc.collect() + torch.cuda.empty_cache() + + model_id = "hf-internal-testing/flux.1-dev-int8-pkg" + t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2") + transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer") + self.pipeline_8bit = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + text_encoder_2=t5_8bit, + transformer=transformer_8bit, + torch_dtype=torch.float16, + ) + self.pipeline_8bit.enable_model_cpu_offload() + + def tearDown(self): + del self.pipeline_8bit + + gc.collect() + torch.cuda.empty_cache() + + def test_quality(self): + # keep the resolution and max tokens to a lower number for faster execution. + output = self.pipeline_8bit( + prompt=self.prompt, + num_inference_steps=self.num_inference_steps, + generator=torch.manual_seed(self.seed), + height=256, + width=256, + max_sequence_length=64, + output_type="np", + ).images + out_slice = output[0, -3:, -3:, -1].flatten() + expected_slice = np.array([0.0574, 0.0554, 0.0581, 0.0686, 0.0676, 0.0759, 0.0757, 0.0803, 0.0930]) + + max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) + self.assertTrue(max_diff < 1e-3) + + +@slow +class BaseBnb8bitSerializationTests(Base8bitTests): + def setUp(self): + gc.collect() + torch.cuda.empty_cache() + + quantization_config = BitsAndBytesConfig( + load_in_8bit=True, + ) + self.model_0 = SD3Transformer2DModel.from_pretrained( + self.model_name, subfolder="transformer", quantization_config=quantization_config + ) + + def tearDown(self): + del self.model_0 + + gc.collect() + torch.cuda.empty_cache() + + def test_serialization(self): + r""" + Test whether it is possible to serialize a model in 8-bit. Uses most typical params as default. 
+ """ + self.assertTrue("_pre_quantization_dtype" in self.model_0.config) + with tempfile.TemporaryDirectory() as tmpdirname: + self.model_0.save_pretrained(tmpdirname) + + config = SD3Transformer2DModel.load_config(tmpdirname) + self.assertTrue("quantization_config" in config) + self.assertTrue("_pre_quantization_dtype" not in config) + + model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname) + + # checking quantized linear module weight + linear = get_some_linear_layer(model_1) + self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + self.assertTrue(hasattr(linear.weight, "SCB")) + + # checking memory footpring + self.assertAlmostEqual(self.model_0.get_memory_footprint() / model_1.get_memory_footprint(), 1, places=2) + + # Matching all parameters and their quant_state items: + d0 = dict(self.model_0.named_parameters()) + d1 = dict(model_1.named_parameters()) + self.assertTrue(d0.keys() == d1.keys()) + + # comparing forward() outputs + dummy_inputs = self.get_dummy_inputs() + inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)} + inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs}) + out_0 = self.model_0(**inputs)[0] + out_1 = model_1(**inputs)[0] + self.assertTrue(torch.equal(out_0, out_1)) + + def test_serialization_sharded(self): + with tempfile.TemporaryDirectory() as tmpdirname: + self.model_0.save_pretrained(tmpdirname, max_shard_size="200MB") + + config = SD3Transformer2DModel.load_config(tmpdirname) + self.assertTrue("quantization_config" in config) + self.assertTrue("_pre_quantization_dtype" not in config) + + model_1 = SD3Transformer2DModel.from_pretrained(tmpdirname) + + # checking quantized linear module weight + linear = get_some_linear_layer(model_1) + self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params) + self.assertTrue(hasattr(linear.weight, "SCB")) + + # comparing forward() outputs + dummy_inputs = self.get_dummy_inputs() + inputs = {k: v.to(torch_device) for k, v in dummy_inputs.items() if isinstance(v, torch.Tensor)} + inputs.update({k: v for k, v in dummy_inputs.items() if k not in inputs}) + out_0 = self.model_0(**inputs)[0] + out_1 = model_1(**inputs)[0] + self.assertTrue(torch.equal(out_0, out_1))