diff --git a/.github/workflows/pr_flax_dependency_test.yml b/.github/workflows/pr_flax_dependency_test.yml
new file mode 100644
index 000000000000..d7d2a2d4c3d5
--- /dev/null
+++ b/.github/workflows/pr_flax_dependency_test.yml
@@ -0,0 +1,34 @@
+name: Run Flax dependency tests
+
+on:
+ pull_request:
+ branches:
+ - main
+ push:
+ branches:
+ - main
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ check_flax_dependencies:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.8"
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -e .
+ pip install "jax[cpu]>=0.2.16,!=0.3.2"
+ pip install "flax>=0.4.1"
+ pip install "jaxlib>=0.1.65"
+ pip install pytest
+ - name: Check for soft dependencies
+ run: |
+ pytest tests/others/test_dependencies.py
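The workflow above only invokes `tests/others/test_dependencies.py`, which is not part of this diff. As a rough illustration, a soft-dependency check along these lines verifies that the package imports cleanly with a single backend installed (hypothetical sketch, not the actual test file):

```python
# Hypothetical sketch of a soft-dependency check; the real
# tests/others/test_dependencies.py is not shown in this diff.
import importlib

import pytest


def test_diffusers_imports_with_a_single_backend():
    # The top-level package must import even when only one backend
    # (Flax here, PyTorch in the sibling workflow) is installed.
    diffusers = importlib.import_module("diffusers")
    assert hasattr(diffusers, "__version__")


@pytest.mark.parametrize("name", ["DiffusionPipeline", "UNet2DConditionModel"])
def test_public_names_resolve_without_optional_backends(name):
    # Public classes should resolve lazily; a missing optional backend
    # should only raise when the object is actually used.
    import diffusers

    assert hasattr(diffusers, name)
```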
diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml
index aaaea147f7ab..f7d9dde5258d 100644
--- a/.github/workflows/pr_tests.yml
+++ b/.github/workflows/pr_tests.yml
@@ -72,7 +72,7 @@ jobs:
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .[quality,test]
- python -m pip install git+https://github.com/huggingface/accelerate.git
+ python -m pip install accelerate
- name: Environment
run: |
@@ -115,7 +115,7 @@ jobs:
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
- examples/test_examples.py
+ examples/test_examples.py
- name: Failure short reports
if: ${{ failure() }}
diff --git a/.github/workflows/pr_torch_dependency_test.yml b/.github/workflows/pr_torch_dependency_test.yml
new file mode 100644
index 000000000000..57a7a5c77c74
--- /dev/null
+++ b/.github/workflows/pr_torch_dependency_test.yml
@@ -0,0 +1,32 @@
+name: Run Torch dependency tests
+
+on:
+ pull_request:
+ branches:
+ - main
+ push:
+ branches:
+ - main
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
+ cancel-in-progress: true
+
+jobs:
+ check_torch_dependencies:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: "3.8"
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -e .
+ pip install torch torchvision torchaudio
+ pip install pytest
+ - name: Check for soft dependencies
+ run: |
+ pytest tests/others/test_dependencies.py
diff --git a/docs/README.md b/docs/README.md
index 30e5d430765e..f85032c68931 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -16,7 +16,7 @@ limitations under the License.
# Generating the documentation
-To generate the documentation, you first have to build it. Several packages are necessary to build the doc,
+To generate the documentation, you first have to build it. Several packages are necessary to build the doc,
you can install them with the following command, at the root of the code repository:
```bash
@@ -142,7 +142,7 @@ This will include every public method of the pipeline that is documented, as wel
- __call__
- enable_attention_slicing
- disable_attention_slicing
- - enable_xformers_memory_efficient_attention
+ - enable_xformers_memory_efficient_attention
- disable_xformers_memory_efficient_attention
```
@@ -154,7 +154,7 @@ Values that should be put in `code` should either be surrounded by backticks: \`
and objects like True, None, or any strings should usually be put in `code`.
When mentioning a class, function, or method, it is recommended to use our syntax for internal links so that our tool
-adds a link to its documentation with this syntax: \[\`XXXClass\`\] or \[\`function\`\]. This requires the class or
+adds a link to its documentation with this syntax: \[\`XXXClass\`\] or \[\`function\`\]. This requires the class or
function to be in the main package.
If you want to create a link to some internal class or function, you need to
diff --git a/docs/TRANSLATING.md b/docs/TRANSLATING.md
index 32cd95f2ade9..b5a88812f30a 100644
--- a/docs/TRANSLATING.md
+++ b/docs/TRANSLATING.md
@@ -38,7 +38,7 @@ Here, `LANG-ID` should be one of the ISO 639-1 or ISO 639-2 language codes -- se
The fun part comes - translating the text!
-The first thing we recommend is translating the part of the `_toctree.yml` file that corresponds to your doc chapter. This file is used to render the table of contents on the website.
+The first thing we recommend is translating the part of the `_toctree.yml` file that corresponds to your doc chapter. This file is used to render the table of contents on the website.
> 🙋 If the `_toctree.yml` file doesn't yet exist for your language, you can create one by copy-pasting from the English version and deleting the sections unrelated to your chapter. Just make sure it exists in the `docs/source/LANG-ID/` directory!
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 3626db3f7b58..a0c6159991b5 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -72,6 +72,8 @@
title: Overview
- local: using-diffusers/sdxl
title: Stable Diffusion XL
+ - local: using-diffusers/lcm
+ title: Latent Consistency Models
- local: using-diffusers/kandinsky
title: Kandinsky
- local: using-diffusers/controlnet
@@ -133,7 +135,7 @@
- local: optimization/memory
title: Reduce memory usage
- local: optimization/torch2.0
- title: Torch 2.0
+ title: PyTorch 2.0
- local: optimization/xformers
title: xFormers
- local: optimization/tome
@@ -200,6 +202,8 @@
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_tiny
title: Tiny AutoEncoder
+ - local: api/models/consistency_decoder_vae
+ title: ConsistencyDecoderVAE
- local: api/models/transformer2d
title: Transformer2D
- local: api/models/transformer_temporal
@@ -344,6 +348,8 @@
title: Overview
- local: api/schedulers/cm_stochastic_iterative
title: CMStochasticIterativeScheduler
+ - local: api/schedulers/consistency_decoder
+ title: ConsistencyDecoderScheduler
- local: api/schedulers/ddim_inverse
title: DDIMInverseScheduler
- local: api/schedulers/ddim
diff --git a/docs/source/en/api/models/consistency_decoder_vae.md b/docs/source/en/api/models/consistency_decoder_vae.md
new file mode 100644
index 000000000000..b45f7fa059dc
--- /dev/null
+++ b/docs/source/en/api/models/consistency_decoder_vae.md
@@ -0,0 +1,18 @@
+# Consistency Decoder
+
+The consistency decoder can be used to decode the latents from the denoising UNet in the [`StableDiffusionPipeline`]. This decoder was introduced in the [DALL-E 3 technical report](https://openai.com/dall-e-3).
+
+The original codebase can be found at [openai/consistencydecoder](https://github.com/openai/consistencydecoder).
+
+<Tip warning={true}>
+
+Inference is only supported for 2 iterations as of now.
+
+</Tip>
+
+The pipeline could not have been contributed without the help of [madebyollin](https://github.com/madebyollin) and [mrsteyk](https://github.com/mrsteyk) from [this issue](https://github.com/openai/consistencydecoder/issues/1).
+
+## ConsistencyDecoderVAE
+[[autodoc]] ConsistencyDecoderVAE
+ - all
+ - decode
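As a usage sketch (not part of this diff), the consistency decoder can be dropped in as the pipeline's VAE; the `"openai/consistency-decoder"` checkpoint id is an assumption here:

```python
# Sketch: use ConsistencyDecoderVAE as the VAE of a Stable Diffusion pipeline.
# The "openai/consistency-decoder" checkpoint id is an assumption.
import torch
from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline

vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
).to("cuda")

image = pipe("a photo of an astronaut riding a horse on mars", generator=torch.manual_seed(0)).images[0]
```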
diff --git a/docs/source/en/api/schedulers/consistency_decoder.md b/docs/source/en/api/schedulers/consistency_decoder.md
new file mode 100644
index 000000000000..6c937b913279
--- /dev/null
+++ b/docs/source/en/api/schedulers/consistency_decoder.md
@@ -0,0 +1,9 @@
+# ConsistencyDecoderScheduler
+
+This scheduler is a part of the [`ConsistencyDecoderPipeline`] and was introduced in [DALL-E 3](https://openai.com/dall-e-3).
+
+The original codebase can be found at [openai/consistency_models](https://github.com/openai/consistency_models).
+
+
+## ConsistencyDecoderScheduler
+[[autodoc]] schedulers.scheduling_consistency_decoder.ConsistencyDecoderScheduler
\ No newline at end of file
diff --git a/docs/source/en/conceptual/ethical_guidelines.md b/docs/source/en/conceptual/ethical_guidelines.md
index fe1d849f44ff..86176bcaa33e 100644
--- a/docs/source/en/conceptual/ethical_guidelines.md
+++ b/docs/source/en/conceptual/ethical_guidelines.md
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
## Preamble
-[Diffusers](https://huggingface.co/docs/diffusers/index) provides pre-trained diffusion models and serves as a modular toolbox for inference and training.
+[Diffusers](https://huggingface.co/docs/diffusers/index) provides pre-trained diffusion models and serves as a modular toolbox for inference and training.
Given its real-world applications and potential negative impacts on society, we think it is important to provide the project with ethical guidelines to guide the development, users’ contributions, and usage of the Diffusers library.
@@ -46,7 +46,7 @@ The following ethical guidelines apply generally, but we will primarily implemen
## Examples of implementations: Safety features and Mechanisms
-The team works daily to make the technical and non-technical tools available to deal with the potential ethical and social risks associated with diffusion technology. Moreover, the community's input is invaluable in ensuring these features' implementation and raising awareness with us.
+The team works daily to make the technical and non-technical tools available to deal with the potential ethical and social risks associated with diffusion technology. Moreover, the community's input is invaluable in ensuring these features' implementation and raising awareness with us.
- [**Community tab**](https://huggingface.co/docs/hub/repositories-pull-requests-discussions): it enables the community to discuss and better collaborate on a project.
@@ -60,4 +60,4 @@ The team works daily to make the technical and non-technical tools available to
- **Staged released on the Hub**: in particularly sensitive situations, access to some repositories should be restricted. This staged release is an intermediary step that allows the repository’s authors to have more control over its use.
-- **Licensing**: [OpenRAILs](https://huggingface.co/blog/open_rail), a new type of licensing, allow us to ensure free access while having a set of restrictions that ensure more responsible use.
+- **Licensing**: [OpenRAILs](https://huggingface.co/blog/open_rail), a new type of licensing, allow us to ensure free access while having a set of restrictions that ensure more responsible use.
diff --git a/docs/source/en/conceptual/evaluation.md b/docs/source/en/conceptual/evaluation.md
index 997c5f4016dc..848eec8620cd 100644
--- a/docs/source/en/conceptual/evaluation.md
+++ b/docs/source/en/conceptual/evaluation.md
@@ -12,9 +12,9 @@ specific language governing permissions and limitations under the License.
# Evaluating Diffusion Models
-
-
-
+
+
+
Evaluation of generative models like [Stable Diffusion](https://huggingface.co/docs/diffusers/stable_diffusion) is subjective in nature. But as practitioners and researchers, we often have to make careful choices amongst many different possibilities. So, when working with different generative models (like GANs, Diffusion, etc.), how do we choose one over the other?
@@ -23,7 +23,7 @@ However, quantitative metrics don't necessarily correspond to image quality. So,
of both qualitative and quantitative evaluations provides a stronger signal when choosing one model
over the other.
-In this document, we provide a non-exhaustive overview of qualitative and quantitative methods to evaluate Diffusion models. For quantitative methods, we specifically focus on how to implement them alongside `diffusers`.
+In this document, we provide a non-exhaustive overview of qualitative and quantitative methods to evaluate Diffusion models. For quantitative methods, we specifically focus on how to implement them alongside `diffusers`.
The methods shown in this document can also be used to evaluate different [noise schedulers](https://huggingface.co/docs/diffusers/main/en/api/schedulers/overview) keeping the underlying generation model fixed.
@@ -38,9 +38,9 @@ We cover Diffusion models with the following pipelines:
## Qualitative Evaluation
Qualitative evaluation typically involves human assessment of generated images. Quality is measured across aspects such as compositionality, image-text alignment, and spatial relations. Common prompts provide a degree of uniformity for subjective metrics.
-DrawBench and PartiPrompts are prompt datasets used for qualitative benchmarking. DrawBench and PartiPrompts were introduced by [Imagen](https://imagen.research.google/) and [Parti](https://parti.research.google/) respectively.
+DrawBench and PartiPrompts are prompt datasets used for qualitative benchmarking. DrawBench and PartiPrompts were introduced by [Imagen](https://imagen.research.google/) and [Parti](https://parti.research.google/) respectively.
-From the [official Parti website](https://parti.research.google/):
+From the [official Parti website](https://parti.research.google/):
> PartiPrompts (P2) is a rich set of over 1600 prompts in English that we release as part of this work. P2 can be used to measure model capabilities across various categories and challenge aspects.
@@ -52,13 +52,13 @@ PartiPrompts has the following columns:
- Category of the prompt (such as “Abstract”, “World Knowledge”, etc.)
- Challenge reflecting the difficulty (such as “Basic”, “Complex”, “Writing & Symbols”, etc.)
-These benchmarks allow for side-by-side human evaluation of different image generation models.
+These benchmarks allow for side-by-side human evaluation of different image generation models.
For this, the 🧨 Diffusers team has built **Open Parti Prompts**, which is a community-driven qualitative benchmark based on Parti Prompts to compare state-of-the-art open-source diffusion models:
- [Open Parti Prompts Game](https://huggingface.co/spaces/OpenGenAI/open-parti-prompts): For 10 parti prompts, 4 generated images are shown and the user selects the image that suits the prompt best.
- [Open Parti Prompts Leaderboard](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard): The leaderboard comparing the currently best open-sourced diffusion models to each other.
-To manually compare images, let’s see how we can use `diffusers` on a couple of PartiPrompts.
+To manually compare images, let’s see how we can use `diffusers` on a couple of PartiPrompts.
Below we show some prompts sampled across different challenges: Basic, Complex, Linguistic Structures, Imagination, and Writing & Symbols. Here we are using PartiPrompts as a [dataset](https://huggingface.co/datasets/nateraw/parti-prompts).
@@ -92,16 +92,16 @@ images = sd_pipeline(sample_prompts, num_images_per_prompt=1, generator=generato
![parti-prompts-14](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts-14.png)
-We can also set `num_images_per_prompt` accordingly to compare different images for the same prompt. Running the same pipeline but with a different checkpoint ([v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)), yields:
+We can also set `num_images_per_prompt` accordingly to compare different images for the same prompt. Running the same pipeline but with a different checkpoint ([v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)), yields:
![parti-prompts-15](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts-15.png)
Once several images are generated from all the prompts using multiple models (under evaluation), these results are presented to human evaluators for scoring. For
-more details on the DrawBench and PartiPrompts benchmarks, refer to their respective papers.
+more details on the DrawBench and PartiPrompts benchmarks, refer to their respective papers.
-
+
-It is useful to look at some inference samples while a model is training to measure the
+It is useful to look at some inference samples while a model is training to measure the
training progress. In our [training scripts](https://github.com/huggingface/diffusers/tree/main/examples/), we support this utility with additional support for
logging to TensorBoard and Weights & Biases.
@@ -177,7 +177,7 @@ generator = torch.manual_seed(seed)
images = sd_pipeline(prompts, num_images_per_prompt=1, generator=generator, output_type="np").images
```
-Then we load the [v1-5 checkpoint](https://huggingface.co/runwayml/stable-diffusion-v1-5) to generate images:
+Then we load the [v1-5 checkpoint](https://huggingface.co/runwayml/stable-diffusion-v1-5) to generate images:
```python
model_ckpt_1_5 = "runwayml/stable-diffusion-v1-5"
@@ -205,7 +205,7 @@ It seems like the [v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)
By construction, there are some limitations in this score. The captions in the training dataset
were crawled from the web and extracted from `alt` and similar tags associated with an image on the internet.
They are not necessarily representative of what a human being would use to describe an image. Hence we
-had to "engineer" some prompts here.
+had to "engineer" some prompts here.
@@ -551,15 +551,15 @@ FID results tend to be fragile as they depend on a lot of factors:
* The implementation accuracy of the computation.
* The image format (not the same if we start from PNGs vs JPGs).
-Keeping that in mind, FID is often most useful when comparing similar runs, but it is
-hard to reproduce paper results unless the authors carefully disclose the FID
+Keeping that in mind, FID is often most useful when comparing similar runs, but it is
+hard to reproduce paper results unless the authors carefully disclose the FID
measurement code.
-These points apply to other related metrics too, such as KID and IS.
+These points apply to other related metrics too, such as KID and IS.
-As a final step, let's visually inspect the `fake_images`.
+As a final step, let's visually inspect the `fake_images`.
diff --git a/docs/source/en/conceptual/philosophy.md b/docs/source/en/conceptual/philosophy.md
index 909ed6bc193d..c7b96abd7f11 100644
--- a/docs/source/en/conceptual/philosophy.md
+++ b/docs/source/en/conceptual/philosophy.md
@@ -27,18 +27,18 @@ In a nutshell, Diffusers is built to be a natural extension of PyTorch. Therefor
## Simple over easy
-As PyTorch states, **explicit is better than implicit** and **simple is better than complex**. This design philosophy is reflected in multiple parts of the library:
+As PyTorch states, **explicit is better than implicit** and **simple is better than complex**. This design philosophy is reflected in multiple parts of the library:
- We follow PyTorch's API with methods like [`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to) to let the user handle device management.
- Raising concise error messages is preferred to silently correct erroneous input. Diffusers aims at teaching the user, rather than making the library as easy to use as possible.
- Complex model vs. scheduler logic is exposed instead of magically handled inside. Schedulers/Samplers are separated from diffusion models with minimal dependencies on each other. This forces the user to write the unrolled denoising loop. However, the separation allows for easier debugging and gives the user more control over adapting the denoising process or switching out diffusion models or schedulers.
-- Separately trained components of the diffusion pipeline, *e.g.* the text encoder, the unet, and the variational autoencoder, each have their own model class. This forces the user to handle the interaction between the different model components, and the serialization format separates the model components into different files. However, this allows for easier debugging and customization. DreamBooth or Textual Inversion training
+- Separately trained components of the diffusion pipeline, *e.g.* the text encoder, the unet, and the variational autoencoder, each have their own model class. This forces the user to handle the interaction between the different model components, and the serialization format separates the model components into different files. However, this allows for easier debugging and customization. DreamBooth or Textual Inversion training
is very simple thanks to Diffusers' ability to separate single components of the diffusion pipeline.
## Tweakable, contributor-friendly over abstraction
-For large parts of the library, Diffusers adopts an important design principle of the [Transformers library](https://github.com/huggingface/transformers), which is to prefer copy-pasted code over hasty abstractions. This design principle is very opinionated and stands in stark contrast to popular design principles such as [Don't repeat yourself (DRY)](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself).
+For large parts of the library, Diffusers adopts an important design principle of the [Transformers library](https://github.com/huggingface/transformers), which is to prefer copy-pasted code over hasty abstractions. This design principle is very opinionated and stands in stark contrast to popular design principles such as [Don't repeat yourself (DRY)](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself).
In short, just like Transformers does for modeling files, Diffusers prefers to keep an extremely low level of abstraction and very self-contained code for pipelines and schedulers.
-Functions, long code blocks, and even classes can be copied across multiple files which at first can look like a bad, sloppy design choice that makes the library unmaintainable.
+Functions, long code blocks, and even classes can be copied across multiple files which at first can look like a bad, sloppy design choice that makes the library unmaintainable.
**However**, this design has proven to be extremely successful for Transformers and makes a lot of sense for community-driven, open-source machine learning libraries because:
- Machine Learning is an extremely fast-moving field in which paradigms, model architectures, and algorithms are changing rapidly, which therefore makes it very difficult to define long-lasting code abstractions.
- Machine Learning practitioners like to be able to quickly tweak existing code for ideation and research and therefore prefer self-contained code over one that contains many abstractions.
@@ -47,10 +47,10 @@ Functions, long code blocks, and even classes can be copied across multiple file
At Hugging Face, we call this design the **single-file policy** which means that almost all of the code of a certain class should be written in a single, self-contained file. To read more about the philosophy, you can have a look
at [this blog post](https://huggingface.co/blog/transformers-design-philosophy).
-In Diffusers, we follow this philosophy for both pipelines and schedulers, but only partly for diffusion models. The reason we don't follow this design fully for diffusion models is because almost all diffusion pipelines, such
+In Diffusers, we follow this philosophy for both pipelines and schedulers, but only partly for diffusion models. The reason we don't follow this design fully for diffusion models is because almost all diffusion pipelines, such
as [DDPM](https://huggingface.co/docs/diffusers/api/pipelines/ddpm), [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview#stable-diffusion-pipelines), [unCLIP (DALL·E 2)](https://huggingface.co/docs/diffusers/api/pipelines/unclip) and [Imagen](https://imagen.research.google/) all rely on the same diffusion model, the [UNet](https://huggingface.co/docs/diffusers/api/models/unet2d-cond).
-Great, now you should have generally understood why 🧨 Diffusers is designed the way it is 🤗.
+Great, now you should have generally understood why 🧨 Diffusers is designed the way it is 🤗.
We try to apply these design principles consistently across the library. Nevertheless, there are some minor exceptions to the philosophy or some unlucky design choices. If you have feedback regarding the design, we would ❤️ to hear it [directly on GitHub](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=).
## Design Philosophy in Details
@@ -89,7 +89,7 @@ The following design principles are followed:
- Models should by default have the highest precision and lowest performance setting.
- To integrate new model checkpoints whose general architecture can be classified as an architecture that already exists in Diffusers, the existing model architecture shall be adapted to make it work with the new checkpoint. One should only create a new file if the model architecture is fundamentally different.
- Models should be designed to be easily extendable to future changes. This can be achieved by limiting public function arguments, configuration arguments, and "foreseeing" future changes, *e.g.* it is usually better to add `string` "...type" arguments that can easily be extended to new future types instead of boolean `is_..._type` arguments. Only the minimum amount of changes shall be made to existing architectures to make a new model checkpoint work.
-- The model design is a difficult trade-off between keeping code readable and concise and supporting many model checkpoints. For most parts of the modeling code, classes shall be adapted for new model checkpoints, while there are some exceptions where it is preferred to add new classes to make sure the code is kept concise and
+- The model design is a difficult trade-off between keeping code readable and concise and supporting many model checkpoints. For most parts of the modeling code, classes shall be adapted for new model checkpoints, while there are some exceptions where it is preferred to add new classes to make sure the code is kept concise and
readable long-term, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
### Schedulers
@@ -97,9 +97,9 @@ readable long-term, such as [UNet blocks](https://github.com/huggingface/diffuse
Schedulers are responsible for guiding the denoising process for inference as well as defining a noise schedule for training. They are designed as individual classes with loadable configuration files and strongly follow the **single-file policy**.
The following design principles are followed:
-- All schedulers are found in [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).
-- Schedulers are **not** allowed to import from large utils files and shall be kept very self-contained.
-- One scheduler Python file corresponds to one scheduler algorithm (as might be defined in a paper).
+- All schedulers are found in [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).
+- Schedulers are **not** allowed to import from large utils files and shall be kept very self-contained.
+- One scheduler Python file corresponds to one scheduler algorithm (as might be defined in a paper).
- If schedulers share similar functionalities, we can make use of the `#Copied from` mechanism.
- Schedulers all inherit from `SchedulerMixin` and `ConfigMixin`.
- Schedulers can be easily swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method as explained in detail [here](../using-diffusers/schedulers.md).
diff --git a/docs/source/en/optimization/coreml.md b/docs/source/en/optimization/coreml.md
index ab96eea0fb04..62809305bfb0 100644
--- a/docs/source/en/optimization/coreml.md
+++ b/docs/source/en/optimization/coreml.md
@@ -31,7 +31,7 @@ Thankfully, Apple engineers developed [a conversion tool](https://github.com/app
Before you convert a model, though, take a moment to explore the Hugging Face Hub – chances are the model you're interested in is already available in Core ML format:
- the [Apple](https://huggingface.co/apple) organization includes Stable Diffusion versions 1.4, 1.5, 2.0 base, and 2.1 base
-- [coreml](https://huggingface.co/coreml) organization includes custom DreamBoothed and finetuned models
+- [coreml community](https://huggingface.co/coreml-community) includes custom finetuned models
- use this [filter](https://huggingface.co/models?pipeline_tag=text-to-image&library=coreml&p=2&sort=likes) to return all available Core ML checkpoints
If you can't find the model you're interested in, we recommend you follow the instructions for [Converting Models to Core ML](https://github.com/apple/ml-stable-diffusion#-converting-models-to-core-ml) by Apple.
@@ -90,7 +90,6 @@ snapshot_download(repo_id, allow_patterns=f"{variant}/*", local_dir=model_path,
print(f"Model downloaded at {model_path}")
```
-
### Inference[[python-inference]]
Once you have downloaded a snapshot of the model, you can test it using Apple's Python script.
@@ -99,7 +98,7 @@ Once you have downloaded a snapshot of the model, you can test it using Apple's
python -m python_coreml_stable_diffusion.pipeline --prompt "a photo of an astronaut riding a horse on mars" -i models/coreml-stable-diffusion-v1-4_original_packages -o --compute-unit CPU_AND_GPU --seed 93
```
-`` should point to the checkpoint you downloaded in the step above, and `--compute-unit` indicates the hardware you want to allow for inference. It must be one of the following options: `ALL`, `CPU_AND_GPU`, `CPU_ONLY`, `CPU_AND_NE`. You may also provide an optional output path, and a seed for reproducibility.
+Pass the path of the downloaded checkpoint to the script with the `-i` flag. `--compute-unit` indicates the hardware you want to allow for inference. It must be one of the following options: `ALL`, `CPU_AND_GPU`, `CPU_ONLY`, `CPU_AND_NE`. You may also provide an optional output path and a seed for reproducibility.
The inference script assumes you're using the original version of the Stable Diffusion model, `CompVis/stable-diffusion-v1-4`. If you use another model, you *have* to specify its Hub id in the inference command line, using the `--model-version` option. This works for models already supported and custom models you trained or fine-tuned yourself.
@@ -109,7 +108,6 @@ For example, if you want to use [`runwayml/stable-diffusion-v1-5`](https://huggi
python -m python_coreml_stable_diffusion.pipeline --prompt "a photo of an astronaut riding a horse on mars" --compute-unit ALL -o output --seed 93 -i models/coreml-stable-diffusion-v1-5_original_packages --model-version runwayml/stable-diffusion-v1-5
```
-
## Core ML inference in Swift
Running inference in Swift is slightly faster than in Python because the models are already compiled in the `mlmodelc` format. This is noticeable on app startup when the model is loaded but shouldn’t be noticeable if you run several generations afterward.
@@ -149,7 +147,6 @@ You have to specify in `--resource-path` one of the checkpoints downloaded in th
For more details, please refer to the [instructions in Apple's repo](https://github.com/apple/ml-stable-diffusion).
-
## Supported Diffusers Features
The Core ML models and inference code don't support many of the features, options, and flexibility of 🧨 Diffusers. These are some of the limitations to keep in mind:
@@ -158,10 +155,10 @@ The Core ML models and inference code don't support many of the features, option
- Only two schedulers have been ported to Swift, the default one used by Stable Diffusion and `DPMSolverMultistepScheduler`, which we ported to Swift from our `diffusers` implementation. We recommend you use `DPMSolverMultistepScheduler`, since it produces the same quality in about half the steps.
- Negative prompts, classifier-free guidance scale, and image-to-image tasks are available in the inference code. Advanced features such as depth guidance, ControlNet, and latent upscalers are not available yet.
-Apple's [conversion and inference repo](https://github.com/apple/ml-stable-diffusion) and our own [swift-coreml-diffusers](https://github.com/huggingface/swift-coreml-diffusers) repos are intended as technology demonstrators to enable other developers to build upon.
+Apple's [conversion and inference repo](https://github.com/apple/ml-stable-diffusion) and our own [swift-coreml-diffusers](https://github.com/huggingface/swift-coreml-diffusers) repos are intended as technology demonstrators to enable other developers to build upon.
-If you feel strongly about any missing features, please feel free to open a feature request or, better yet, a contribution PR :)
+If you feel strongly about any missing features, please feel free to open a feature request or, better yet, a contribution PR 🙂.
## Native Diffusers Swift app
-One easy way to run Stable Diffusion on your own Apple hardware is to use [our open-source Swift repo](https://github.com/huggingface/swift-coreml-diffusers), based on `diffusers` and Apple's conversion and inference repo. You can study the code, compile it with [Xcode](https://developer.apple.com/xcode/) and adapt it for your own needs. For your convenience, there's also a [standalone Mac app in the App Store](https://apps.apple.com/app/diffusers/id1666309574), so you can play with it without having to deal with the code or IDE. If you are a developer and have determined that Core ML is the best solution to build your Stable Diffusion app, then you can use the rest of this guide to get started with your project. We can't wait to see what you'll build :)
+One easy way to run Stable Diffusion on your own Apple hardware is to use [our open-source Swift repo](https://github.com/huggingface/swift-coreml-diffusers), based on `diffusers` and Apple's conversion and inference repo. You can study the code, compile it with [Xcode](https://developer.apple.com/xcode/) and adapt it for your own needs. For your convenience, there's also a [standalone Mac app in the App Store](https://apps.apple.com/app/diffusers/id1666309574), so you can play with it without having to deal with the code or IDE. If you are a developer and have determined that Core ML is the best solution to build your Stable Diffusion app, then you can use the rest of this guide to get started with your project. We can't wait to see what you'll build 🙂.
diff --git a/docs/source/en/optimization/fp16.md b/docs/source/en/optimization/fp16.md
index 2ac16786eb46..61bc5569c53c 100644
--- a/docs/source/en/optimization/fp16.md
+++ b/docs/source/en/optimization/fp16.md
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# Speed up inference
-There are several ways to optimize 🤗 Diffusers for inference speed. As a general rule of thumb, we recommend using either [xFormers](xformers) or `torch.nn.functional.scaled_dot_product_attention` in PyTorch 2.0 for their memory-efficient attention.
+There are several ways to optimize 🤗 Diffusers for inference speed. As a general rule of thumb, we recommend using either [xFormers](xformers) or `torch.nn.functional.scaled_dot_product_attention` in PyTorch 2.0 for their memory-efficient attention.
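As a sketch of that recommendation (assuming a CUDA device, and for the xFormers path that `xformers` is installed):

```python
# Sketch: memory-efficient attention for faster inference.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")

# On PyTorch 2.0+, scaled_dot_product_attention is used automatically.
# On older PyTorch versions, opt into xFormers instead (requires xformers):
# pipe.enable_xformers_memory_efficient_attention()

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```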
@@ -64,5 +64,5 @@ image = pipe(prompt).images[0]
Don't use [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast) in any of the pipelines as it can lead to black images and is always slower than pure float16 precision.
-
-
\ No newline at end of file
+
+
diff --git a/docs/source/en/optimization/habana.md b/docs/source/en/optimization/habana.md
index c78c8ca3a1be..8a06210996f3 100644
--- a/docs/source/en/optimization/habana.md
+++ b/docs/source/en/optimization/habana.md
@@ -55,8 +55,7 @@ outputs = pipeline(
)
```
-For more information, check out 🤗 Optimum Habana's [documentation](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion) and the [example](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) provided in the official Github repository.
-
+For more information, check out 🤗 Optimum Habana's [documentation](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion) and the [example](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion) provided in the official GitHub repository.
## Benchmark
diff --git a/docs/source/en/optimization/memory.md b/docs/source/en/optimization/memory.md
index c91fed1b2784..281b65df8d8c 100644
--- a/docs/source/en/optimization/memory.md
+++ b/docs/source/en/optimization/memory.md
@@ -1,3 +1,15 @@
+
+
# Reduce memory usage
A barrier to using diffusion models is the large amount of memory required. To overcome this challenge, there are several memory-reducing techniques you can use to run even some of the largest models on free-tier or consumer GPUs. Some of these techniques can even be combined to further reduce memory usage.
@@ -18,10 +30,9 @@ The results below are obtained from generating a single 512x512 image from the p
| traced UNet | 3.21s | x2.96 |
| memory-efficient attention | 2.63s | x3.61 |
-
## Sliced VAE
-Sliced VAE enables decoding large batches of images with limited VRAM or batches with 32 images or more by decoding the batches of latents one image at a time. You'll likely want to couple this with [`~ModelMixin.enable_xformers_memory_efficient_attention`] to further reduce memory use.
+Sliced VAE enables decoding large batches of images with limited VRAM or batches with 32 images or more by decoding the batches of latents one image at a time. You'll likely want to couple this with [`~ModelMixin.enable_xformers_memory_efficient_attention`] to reduce memory use further if you have xFormers installed.
To use sliced VAE, call [`~StableDiffusionPipeline.enable_vae_slicing`] on your pipeline before inference:
@@ -38,6 +49,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_vae_slicing()
+#pipe.enable_xformers_memory_efficient_attention()
images = pipe([prompt] * 32).images
```
@@ -45,7 +57,7 @@ You may see a small performance boost in VAE decoding on multi-image batches, an
## Tiled VAE
-Tiled VAE processing also enables working with large images on limited VRAM (for example, generating 4k images on 8GB of VRAM) by splitting the image into overlapping tiles, decoding the tiles, and then blending the outputs together to compose the final image. You should also used tiled VAE with [`~ModelMixin.enable_xformers_memory_efficient_attention`] to further reduce memory use.
+Tiled VAE processing also enables working with large images on limited VRAM (for example, generating 4k images on 8GB of VRAM) by splitting the image into overlapping tiles, decoding the tiles, and then blending the outputs together to compose the final image. You should also use tiled VAE with [`~ModelMixin.enable_xformers_memory_efficient_attention`] to reduce memory use further if you have xFormers installed.
To use tiled VAE processing, call [`~StableDiffusionPipeline.enable_vae_tiling`] on your pipeline before inference:
@@ -62,7 +74,7 @@ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "a beautiful landscape photograph"
pipe.enable_vae_tiling()
-pipe.enable_xformers_memory_efficient_attention()
+#pipe.enable_xformers_memory_efficient_attention()
image = pipe([prompt], width=3840, height=2224, num_inference_steps=20).images[0]
```
@@ -98,24 +110,6 @@ Consider using [model offloading](#model-offloading) if you want to optimize for
-CPU offloading can also be chained with attention slicing to reduce memory consumption to less than 2GB.
-
-```Python
-import torch
-from diffusers import StableDiffusionPipeline
-
-pipe = StableDiffusionPipeline.from_pretrained(
- "runwayml/stable-diffusion-v1-5",
- torch_dtype=torch.float16,
- use_safetensors=True,
-)
-
-prompt = "a photo of an astronaut riding a horse on mars"
-pipe.enable_sequential_cpu_offload()
-
-image = pipe(prompt).images[0]
-```
-
When using [`~StableDiffusionPipeline.enable_sequential_cpu_offload`], don't move the pipeline to CUDA beforehand or else the gain in memory consumption will only be minimal (see this [issue](https://github.com/huggingface/diffusers/issues/1934) for more information).
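Following the snippet removed above and the caution here, a minimal sequential-offload sketch looks like this (note the absence of a `.to("cuda")` call):

```python
# Sketch of sequential CPU offload; do not move the pipeline to CUDA first.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipe.enable_sequential_cpu_offload()  # submodules are moved to the GPU only when needed

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
```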
@@ -145,23 +139,6 @@ Enable model offloading by calling [`~StableDiffusionPipeline.enable_model_cpu_o
import torch
from diffusers import StableDiffusionPipeline
-pipe = StableDiffusionPipeline.from_pretrained(
- "runwayml/stable-diffusion-v1-5",
- torch_dtype=torch.float16,
- use_safetensors=True,
-)
-
-prompt = "a photo of an astronaut riding a horse on mars"
-pipe.enable_model_cpu_offload()
-image = pipe(prompt).images[0]
-```
-
-Model offloading can also be combined with attention slicing for additional memory savings.
-
-```Python
-import torch
-from diffusers import StableDiffusionPipeline
-
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
@@ -170,14 +147,12 @@ pipe = StableDiffusionPipeline.from_pretrained(
prompt = "a photo of an astronaut riding a horse on mars"
pipe.enable_model_cpu_offload()
-
image = pipe(prompt).images[0]
```
-In order to properly offload models after they're called, it is required to run the entire pipeline and models are called in the pipeline's expected order. Exercise caution if models are reused outside the context of the pipeline after hooks have been installed. See [Removing Hooks](https://huggingface.co/docs/accelerate/en/package_reference/big_modeling#accelerate.hooks.remove_hook_from_module)
-for more information.
+To properly offload models after they're called, the entire pipeline must be run and the models must be called in the pipeline's expected order. Exercise caution if models are reused outside the context of the pipeline after hooks have been installed. See [Removing Hooks](https://huggingface.co/docs/accelerate/en/package_reference/big_modeling#accelerate.hooks.remove_hook_from_module) for more information.
[`~StableDiffusionPipeline.enable_model_cpu_offload`] is a stateful operation that installs hooks on the models and state on the pipeline.
@@ -303,7 +278,7 @@ unet_traced = torch.jit.load("unet_traced.pt")
class TracedUNet(torch.nn.Module):
def __init__(self):
super().__init__()
- self.in_channels = pipe.unet.in_channels
+ self.in_channels = pipe.unet.config.in_channels
self.device = pipe.unet.device
def forward(self, latent_model_input, t, encoder_hidden_states):
@@ -319,7 +294,7 @@ with torch.inference_mode():
## Memory-efficient attention
-Recent work on optimizing bandwidth in the attention block has generated huge speed-ups and reductions in GPU memory usage. The most recent type of memory-efficient attention is [Flash Attention](https://arxiv.org/pdf/2205.14135.pdf) (you can check out the original code at [HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)).
+Recent work on optimizing bandwidth in the attention block has generated huge speed-ups and reductions in GPU memory usage. The most recent type of memory-efficient attention is [Flash Attention](https://arxiv.org/abs/2205.14135) (you can check out the original code at [HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)).
@@ -354,4 +329,4 @@ with torch.inference_mode():
# pipe.disable_xformers_memory_efficient_attention()
```
-The iteration speed when using `xformers` should match the iteration speed of Torch 2.0 as described [here](torch2.0).
+The iteration speed when using `xformers` should match the iteration speed of PyTorch 2.0 as described [here](torch2.0).
diff --git a/docs/source/en/optimization/mps.md b/docs/source/en/optimization/mps.md
index 138c85b51184..f5ce3332fc90 100644
--- a/docs/source/en/optimization/mps.md
+++ b/docs/source/en/optimization/mps.md
@@ -31,6 +31,8 @@ pipe = pipe.to("mps")
pipe.enable_attention_slicing()
prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).images[0]
+image
```
@@ -48,10 +50,10 @@ If you're using **PyTorch 1.13**, you need to "prime" the pipeline with an addit
pipe.enable_attention_slicing()
prompt = "a photo of an astronaut riding a horse on mars"
-# First-time "warmup" pass if PyTorch version is 1.13
+ # First-time "warmup" pass if PyTorch version is 1.13
+ _ = pipe(prompt, num_inference_steps=1)
-# Results match those from the CPU device after the warmup pass.
+ # Results match those from the CPU device after the warmup pass.
image = pipe(prompt).images[0]
```
@@ -63,6 +65,7 @@ To prevent this from happening, we recommend *attention slicing* to reduce memor
```py
from diffusers import DiffusionPipeline
+import torch
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to("mps")
pipeline.enable_attention_slicing()
diff --git a/docs/source/en/optimization/onnx.md b/docs/source/en/optimization/onnx.md
index 20104b555543..4d352480a007 100644
--- a/docs/source/en/optimization/onnx.md
+++ b/docs/source/en/optimization/onnx.md
@@ -10,13 +10,12 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-
# ONNX Runtime
🤗 [Optimum](https://github.com/huggingface/optimum) provides a Stable Diffusion pipeline compatible with ONNX Runtime. You'll need to install 🤗 Optimum with the following command for ONNX Runtime support:
```bash
-pip install optimum["onnxruntime"]
+pip install -q optimum["onnxruntime"]
```
This guide will show you how to use the Stable Diffusion and Stable Diffusion XL (SDXL) pipelines with ONNX Runtime.
@@ -50,7 +49,7 @@ optimum-cli export onnx --model runwayml/stable-diffusion-v1-5 sd_v15_onnx/
Then to perform inference (you don't have to specify `export=True` again):
-```python
+```python
from optimum.onnxruntime import ORTStableDiffusionPipeline
model_id = "sd_v15_onnx"
diff --git a/docs/source/en/optimization/open_vino.md b/docs/source/en/optimization/open_vino.md
index 606c2207bcda..29299786118a 100644
--- a/docs/source/en/optimization/open_vino.md
+++ b/docs/source/en/optimization/open_vino.md
@@ -10,14 +10,13 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-
# OpenVINO
-🤗 [Optimum](https://github.com/huggingface/optimum-intel) provides Stable Diffusion pipelines compatible with OpenVINO to perform inference on a variety of Intel processors (see the [full list]((https://docs.openvino.ai/latest/openvino_docs_OV_UG_supported_plugins_Supported_Devices.html)) of supported devices).
+🤗 [Optimum](https://github.com/huggingface/optimum-intel) provides Stable Diffusion pipelines compatible with OpenVINO to perform inference on a variety of Intel processors (see the [full list](https://docs.openvino.ai/latest/openvino_docs_OV_UG_supported_plugins_Supported_Devices.html) of supported devices).
You'll need to install 🤗 Optimum Intel with the `--upgrade-strategy eager` option to ensure [`optimum-intel`](https://github.com/huggingface/optimum-intel) is using the latest version:
-```
+```bash
pip install --upgrade-strategy eager optimum["openvino"]
```
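Once installed, inference runs through Optimum Intel's OpenVINO pipeline class; a minimal sketch (the on-the-fly `export=True` conversion follows the Optimum Intel API):

```python
# Sketch: export a PyTorch checkpoint to OpenVINO on the fly and run inference.
from optimum.intel import OVStableDiffusionPipeline

model_id = "runwayml/stable-diffusion-v1-5"
pipeline = OVStableDiffusionPipeline.from_pretrained(model_id, export=True)

image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
# pipeline.save_pretrained("openvino-sd-v1-5")  # optionally keep the exported model
```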
diff --git a/docs/source/en/optimization/tome.md b/docs/source/en/optimization/tome.md
index 66d69c6900cc..34726a4c79c2 100644
--- a/docs/source/en/optimization/tome.md
+++ b/docs/source/en/optimization/tome.md
@@ -14,18 +14,25 @@ specific language governing permissions and limitations under the License.
[Token merging](https://huggingface.co/papers/2303.17604) (ToMe) progressively merges redundant tokens/patches in the forward pass of a Transformer-based network, which can reduce the inference latency of [`StableDiffusionPipeline`].
+Install ToMe from `pip`:
+
+```bash
+pip install tomesd
+```
+
You can use ToMe from the [`tomesd`](https://github.com/dbolya/tomesd) library with the [`apply_patch`](https://github.com/dbolya/tomesd?tab=readme-ov-file#usage) function:
```diff
-from diffusers import StableDiffusionPipeline
-import tomesd
+ from diffusers import StableDiffusionPipeline
+ import torch
+ import tomesd
-pipeline = StableDiffusionPipeline.from_pretrained(
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True,
-).to("cuda")
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True,
+ ).to("cuda")
+ tomesd.apply_patch(pipeline, ratio=0.5)
-image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
+ image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
```
The `apply_patch` function exposes a number of [arguments](https://github.com/dbolya/tomesd#usage) to help strike a balance between pipeline inference speed and the quality of the generated tokens. The most important argument is `ratio` which controls the number of tokens that are merged during the forward pass.
diff --git a/docs/source/en/optimization/torch2.0.md b/docs/source/en/optimization/torch2.0.md
index 1e07b876514f..4775fda0fcf9 100644
--- a/docs/source/en/optimization/torch2.0.md
+++ b/docs/source/en/optimization/torch2.0.md
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Torch 2.0
+# PyTorch 2.0
🤗 Diffusers supports the latest optimizations from [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/) which include:
@@ -48,7 +48,6 @@ In some cases - such as making the pipeline more deterministic or converting it
```diff
import torch
from diffusers import DiffusionPipeline
- from diffusers.models.attention_processor import AttnProcessor
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
+ pipe.unet.set_default_attn_processor()
@@ -110,17 +109,14 @@ for _ in range(3):
### Stable Diffusion image-to-image
-```python
+```python
from diffusers import StableDiffusionImg2ImgPipeline
-import requests
+from diffusers.utils import load_image
import torch
-from PIL import Image
-from io import BytesIO
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
-response = requests.get(url)
-init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = load_image(url)
init_image = init_image.resize((512, 512))
path = "runwayml/stable-diffusion-v1-5"
@@ -143,25 +139,16 @@ for _ in range(3):
### Stable Diffusion inpainting
-```python
+```python
from diffusers import StableDiffusionInpaintPipeline
-import requests
+from diffusers.utils import load_image
import torch
-from PIL import Image
-from io import BytesIO
-
-url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
-
-def download_image(url):
- response = requests.get(url)
- return Image.open(BytesIO(response.content)).convert("RGB")
-
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
-init_image = download_image(img_url).resize((512, 512))
-mask_image = download_image(mask_url).resize((512, 512))
+init_image = load_image(img_url).resize((512, 512))
+mask_image = load_image(mask_url).resize((512, 512))
path = "runwayml/stable-diffusion-inpainting"
@@ -183,17 +170,14 @@ for _ in range(3):
### ControlNet
-```python
+```python
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
-import requests
+from diffusers.utils import load_image
import torch
-from PIL import Image
-from io import BytesIO
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
-response = requests.get(url)
-init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = load_image(url)
init_image = init_image.resize((512, 512))
path = "runwayml/stable-diffusion-v1-5"
@@ -221,26 +205,26 @@ for _ in range(3):
### DeepFloyd IF text-to-image + upscaling
-```python
+```python
from diffusers import DiffusionPipeline
import torch
run_compile = True # Set True / False
-pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-M-v1.0", variant="fp16", text_encoder=None, torch_dtype=torch.float16, use_safetensors=True)
-pipe.to("cuda")
+pipe_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-M-v1.0", variant="fp16", text_encoder=None, torch_dtype=torch.float16, use_safetensors=True)
+pipe_1.to("cuda")
pipe_2 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-M-v1.0", variant="fp16", text_encoder=None, torch_dtype=torch.float16, use_safetensors=True)
pipe_2.to("cuda")
pipe_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16, use_safetensors=True)
pipe_3.to("cuda")
-pipe.unet.to(memory_format=torch.channels_last)
+pipe_1.unet.to(memory_format=torch.channels_last)
pipe_2.unet.to(memory_format=torch.channels_last)
pipe_3.unet.to(memory_format=torch.channels_last)
if run_compile:
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ pipe_1.unet = torch.compile(pipe_1.unet, mode="reduce-overhead", fullgraph=True)
pipe_2.unet = torch.compile(pipe_2.unet, mode="reduce-overhead", fullgraph=True)
pipe_3.unet = torch.compile(pipe_3.unet, mode="reduce-overhead", fullgraph=True)
@@ -250,9 +234,9 @@ prompt_embeds = torch.randn((1, 2, 4096), dtype=torch.float16)
neg_prompt_embeds = torch.randn((1, 2, 4096), dtype=torch.float16)
for _ in range(3):
- image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type="pt").images
- image_2 = pipe_2(image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type="pt").images
- image_3 = pipe_3(prompt=prompt, image=image, noise_level=100).images
+ image_1 = pipe_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type="pt").images
+ image_2 = pipe_2(image=image_1, prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type="pt").images
+ image_3 = pipe_3(prompt=prompt, image=image_1, noise_level=100).images
```
@@ -426,9 +410,9 @@ In the following tables, we report our findings in terms of the *number of itera
| IF | 9.26 | 9.2 | ❌ | 13.31 |
| SDXL - txt2img | 0.52 | 0.53 | - | - |
-## Notes
+## Notes
-* Follow this [PR](https://github.com/huggingface/diffusers/pull/3313) for more details on the environment used for conducting the benchmarks.
+* Follow this [PR](https://github.com/huggingface/diffusers/pull/3313) for more details on the environment used for conducting the benchmarks.
* For the DeepFloyd IF pipeline where batch sizes > 1, we only used a batch size of > 1 in the first IF pipeline for text-to-image generation and NOT for upscaling. That means the two upscaling pipelines received a batch size of 1.
*Thanks to [Horace He](https://github.com/Chillee) from the PyTorch team for their support in improving our support of `torch.compile()` in Diffusers.*
diff --git a/docs/source/en/quicktour.md b/docs/source/en/quicktour.md
index c5ead9829cdc..89792d5c05b3 100644
--- a/docs/source/en/quicktour.md
+++ b/docs/source/en/quicktour.md
@@ -257,7 +257,7 @@ To predict a slightly less noisy image, pass the following to the scheduler's [`
torch.Size([1, 3, 256, 256])
```
-The `less_noisy_sample` can be passed to the next `timestep` where it'll get even less noisy! Let's bring it all together now and visualize the entire denoising process.
+The `less_noisy_sample` can be passed to the next `timestep` where it'll get even less noisy! Let's bring it all together now and visualize the entire denoising process.
First, create a function that postprocesses and displays the denoised image as a `PIL.Image`:
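The helper itself falls outside this hunk; a sketch of such a postprocessing function, assuming a notebook environment:

```python
import numpy as np
import PIL.Image
from IPython.display import display


def display_sample(sample, i):
    # Map the model output from [-1, 1] tensors to a uint8 RGB image.
    image_processed = sample.cpu().permute(0, 2, 3, 1)
    image_processed = (image_processed + 1.0) * 127.5
    image_processed = image_processed.numpy().astype(np.uint8)

    image_pil = PIL.Image.fromarray(image_processed[0])
    display(f"Image at step {i}")
    display(image_pil)
```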
diff --git a/docs/source/en/stable_diffusion.md b/docs/source/en/stable_diffusion.md
index 06eb5bf15f23..c0298eeeb3c1 100644
--- a/docs/source/en/stable_diffusion.md
+++ b/docs/source/en/stable_diffusion.md
@@ -9,12 +9,12 @@ Unless required by applicable law or agreed to in writing, software distributed
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
-
+
# Effective and efficient diffusion
[[open-in-colab]]
-Getting the [`DiffusionPipeline`] to generate images in a certain style or include what you want can be tricky. Often times, you have to run the [`DiffusionPipeline`] several times before you end up with an image you're happy with. But generating something out of nothing is a computationally intensive process, especially if you're running inference over and over again.
+Getting the [`DiffusionPipeline`] to generate images in a certain style or include what you want can be tricky. Often times, you have to run the [`DiffusionPipeline`] several times before you end up with an image you're happy with. But generating something out of nothing is a computationally intensive process, especially if you're running inference over and over again.
This is why it's important to get the most *computational* (speed) and *memory* (GPU vRAM) efficiency from the pipeline to reduce the time between inference cycles so you can iterate faster.
@@ -68,7 +68,7 @@ image
-This process took ~30 seconds on a T4 GPU (it might be faster if your allocated GPU is better than a T4). By default, the [`DiffusionPipeline`] runs inference with full `float32` precision for 50 inference steps. You can speed this up by switching to a lower precision like `float16` or running fewer inference steps.
+This process took ~30 seconds on a T4 GPU (it might be faster if your allocated GPU is better than a T4). By default, the [`DiffusionPipeline`] runs inference with full `float32` precision for 50 inference steps. You can speed this up by switching to a lower precision like `float16` or running fewer inference steps.
Let's start by loading the model in `float16` and generate an image:
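A minimal sketch of that step, assuming the `runwayml/stable-diffusion-v1-5` checkpoint and a placeholder prompt:

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")

# placeholder prompt; substitute the prompt used in the guide
prompt = "portrait photo of an old warrior chief"
generator = torch.Generator("cuda").manual_seed(0)
image = pipeline(prompt, generator=generator).images[0]
image
```

Half-precision roughly halves the memory footprint and typically speeds up inference on recent GPUs.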
diff --git a/docs/source/en/training/lora.md b/docs/source/en/training/lora.md
index 28a9adf3ec61..7c13b7af9d7d 100644
--- a/docs/source/en/training/lora.md
+++ b/docs/source/en/training/lora.md
@@ -113,14 +113,15 @@ Load the LoRA weights from your finetuned model *on top of the base model weight
```py
>>> pipe.unet.load_attn_procs(lora_model_path)
>>> pipe.to("cuda")
-# use half the weights from the LoRA finetuned model and half the weights from the base model
+# use half the weights from the LoRA finetuned model and half the weights from the base model
>>> image = pipe(
... "A pokemon with blue eyes.", num_inference_steps=25, guidance_scale=7.5, cross_attention_kwargs={"scale": 0.5}
... ).images[0]
-# use the weights from the fully finetuned LoRA model
->>> image = pipe("A pokemon with blue eyes.", num_inference_steps=25, guidance_scale=7.5).images[0]
+# OR, use the weights from the fully finetuned LoRA model
+# >>> image = pipe("A pokemon with blue eyes.", num_inference_steps=25, guidance_scale=7.5).images[0]
+
>>> image.save("blue_pokemon.png")
```
@@ -225,17 +226,18 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m
```py
>>> pipe.unet.load_attn_procs(lora_model_path)
>>> pipe.to("cuda")
-# use half the weights from the LoRA finetuned model and half the weights from the base model
+# use half the weights from the LoRA finetuned model and half the weights from the base model
>>> image = pipe(
... "A picture of a sks dog in a bucket.",
... num_inference_steps=25,
... guidance_scale=7.5,
... cross_attention_kwargs={"scale": 0.5},
... ).images[0]
-# use the weights from the fully finetuned LoRA model
->>> image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0]
+# OR, use the weights from the fully finetuned LoRA model
+# >>> image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0]
+
>>> image.save("bucket-dog.png")
```
diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md
index 2e3337519caa..da69b712a989 100644
--- a/docs/source/en/tutorials/using_peft_for_inference.md
+++ b/docs/source/en/tutorials/using_peft_for_inference.md
@@ -12,9 +12,9 @@ specific language governing permissions and limitations under the License.
[[open-in-colab]]
-# Inference with PEFT
+# Load LoRAs for inference
-There are many adapters trained in different styles to achieve different effects. You can even combine multiple adapters to create new and unique images. With the 🤗 [PEFT](https://huggingface.co/docs/peft/index) integration in 🤗 Diffusers, it is really easy to load and manage adapters for inference. In this guide, you'll learn how to use different adapters with [Stable Diffusion XL (SDXL)](../api/pipelines/stable_diffusion/stable_diffusion_xl) for inference.
+There are many adapters (with LoRAs being the most common type) trained in different styles to achieve different effects. You can even combine multiple adapters to create new and unique images. With the 🤗 [PEFT](https://huggingface.co/docs/peft/index) integration in 🤗 Diffusers, it is really easy to load and manage adapters for inference. In this guide, you'll learn how to use different adapters with [Stable Diffusion XL (SDXL)](../api/pipelines/stable_diffusion/stable_diffusion_xl) for inference.
Throughout this guide, you'll use LoRA as the main adapter technique, so we'll use the terms LoRA and adapter interchangeably. You should have some familiarity with LoRA, and if you don't, we welcome you to check out the [LoRA guide](https://huggingface.co/docs/peft/conceptual_guides/lora).
@@ -22,9 +22,8 @@ Let's first install all the required libraries.
```bash
!pip install -q transformers accelerate
-# Will be updated once the stable releases are done.
-!pip install -q git+https://github.com/huggingface/peft.git
-!pip install -q git+https://github.com/huggingface/diffusers.git
+!pip install peft
+!pip install diffusers
```
Now, let's load a pipeline with a SDXL checkpoint:
@@ -165,3 +164,22 @@ list_adapters_component_wise = pipe.get_list_adapters()
list_adapters_component_wise
{"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
```
+
+## Fusing adapters into the model
+
+You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~diffusers.loaders.LoraLoaderMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.
+
+```py
+pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
+
+pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
+# Fuses the LoRAs into the Unet
+pipe.fuse_lora()
+
+prompt = "toy_face of a hacker with a hoodie, pixel art"
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+
+# Gets the Unet back to the original state
+pipe.unfuse_lora()
+```
diff --git a/docs/source/en/using-diffusers/callback.md b/docs/source/en/using-diffusers/callback.md
index 2293dc1f60c6..b4f16bda55eb 100644
--- a/docs/source/en/using-diffusers/callback.md
+++ b/docs/source/en/using-diffusers/callback.md
@@ -10,18 +10,18 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Using callback
+# Using callback
[[open-in-colab]]
-Most 🤗 Diffusers pipeline now accept a `callback_on_step_end` argument that allows you to change the default behavior of denoising loop with custom defined functions. Here is an example of a callback function we can write to disable classifier free guidance after 40% of inference steps to save compute with minimum tradeoff in performance.
+Most 🤗 Diffusers pipelines now accept a `callback_on_step_end` argument that allows you to change the default behavior of the denoising loop with custom-defined functions. Here is an example of a callback function we can write to disable classifier-free guidance after 40% of the inference steps to save compute with a minimal tradeoff in performance.
```python
-def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
+def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
# adjust the batch_size of prompt_embeds according to guidance_scale
if step_index == int(pipe.num_timestep * 0.4):
prompt_embeds = callback_kwargs["prompt_embeds"]
- prompt_embeds =prompt_embeds.chunk(2)[-1]
+ prompt_embeds = prompt_embeds.chunk(2)[-1]
# update guidance_scale and prompt_embeds
pipe._guidance_scale = 0.0
@@ -34,9 +34,9 @@ Your callback function has below arguments:
* `step_index` and `timestep` tell you where you are in the denoising loop. In our example, we use `step_index` to decide when to turn off CFG.
* `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables so please check the pipeline class's `_callback_tensor_inputs` attribute for the list of variables that you can modify. Common variables include `latents` and `prompt_embeds`. In our example, we need to adjust the batch size of `prompt_embeds` after setting `guidance_scale` to be `0` in order for it to work properly.
-You can pass the callback function as `callback_on_step_end` argument to the pipeline along with `callback_on_step_end_tensor_inputs`.
+You can pass the callback function as the `callback_on_step_end` argument to the pipeline along with `callback_on_step_end_tensor_inputs`.
-```
+```python
import torch
from diffusers import StableDiffusionPipeline
@@ -46,7 +46,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
generator = torch.Generator(device="cuda").manual_seed(1)
-out= pipe(prompt, generator=generator, callback_on_step_end = callback_custom_cfg, callback_on_step_end_tensor_inputs=['prompt_embeds'])
+out = pipe(prompt, generator=generator, callback_on_step_end=callback_dynamic_cfg, callback_on_step_end_tensor_inputs=['prompt_embeds'])
out.images[0].save("out_custom_cfg.png")
```
@@ -55,6 +55,6 @@ Your callback function will be executed at the end of each denoising step and mo
-Currently we only support `callback_on_step_end`. If you have a solid use case and require a callback function with a different execution point, please open an [feature request](https://github.com/huggingface/diffusers/issues/new/choose) so we can add it!
+Currently we only support `callback_on_step_end`. If you have a solid use case and require a callback function with a different execution point, please open a [Feature Request](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&projects=&template=feature_request.md&title=) so we can add it!
-
\ No newline at end of file
+
diff --git a/docs/source/en/using-diffusers/conditional_image_generation.md b/docs/source/en/using-diffusers/conditional_image_generation.md
index d07658e4f58e..9832f53cffe6 100644
--- a/docs/source/en/using-diffusers/conditional_image_generation.md
+++ b/docs/source/en/using-diffusers/conditional_image_generation.md
@@ -30,6 +30,7 @@ You can generate images from a prompt in 🤗 Diffusers in two steps:
```py
from diffusers import AutoPipelineForText2Image
+import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
@@ -42,6 +43,7 @@ pipeline = AutoPipelineForText2Image.from_pretrained(
image = pipeline(
"stained glass of darth vader, backlight, centered composition, masterpiece, photorealistic, 8k"
).images[0]
+image
```
@@ -65,6 +67,7 @@ pipeline = AutoPipelineForText2Image.from_pretrained(
).to("cuda")
generator = torch.Generator("cuda").manual_seed(31)
image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", generator=generator).images[0]
+image
```
### Stable Diffusion XL
@@ -80,6 +83,7 @@ pipeline = AutoPipelineForText2Image.from_pretrained(
).to("cuda")
generator = torch.Generator("cuda").manual_seed(31)
image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", generator=generator).images[0]
+image
```
### Kandinsky 2.2
@@ -93,15 +97,16 @@ from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
- "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16"
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16
).to("cuda")
generator = torch.Generator("cuda").manual_seed(31)
image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", generator=generator).images[0]
+image
```
### ControlNet
-ControlNet are auxiliary models or adapters that are finetuned on top of text-to-image models, such as [Stable Diffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5). Using ControlNet models in combination with text-to-image models offers diverse options for more explicit control over how to generate an image. With ControlNet's, you add an additional conditioning input image to the model. For example, if you provide an image of a human pose (usually represented as multiple keypoints that are connected into a skeleton) as a conditioning input, the model generates an image that follows the pose of the image. Check out the more in-depth [ControlNet](controlnet) guide to learn more about other conditioning inputs and how to use them.
+ControlNet models are auxiliary models or adapters that are finetuned on top of text-to-image models, such as [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5). Using ControlNet models in combination with text-to-image models offers diverse options for more explicit control over how to generate an image. With ControlNet, you add an additional conditioning input image to the model. For example, if you provide an image of a human pose (usually represented as multiple keypoints that are connected into a skeleton) as a conditioning input, the model generates an image that follows the pose of the image. Check out the more in-depth [ControlNet](controlnet) guide to learn more about other conditioning inputs and how to use them.
In this example, let's condition the ControlNet with a human pose estimation image. Load the ControlNet model pretrained on human pose estimations:
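A minimal sketch of that loading step, assuming a pose-conditioned checkpoint such as `lllyasviel/control_v11p_sd15_openpose` and a pose image of your own (the path below is a placeholder):

```py
import torch
from diffusers import ControlNetModel
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16
)
# placeholder path; any human pose estimation image works here
pose_image = load_image("path/or/url/to/pose_image.png")
```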
@@ -124,6 +129,7 @@ pipeline = AutoPipelineForText2Image.from_pretrained(
).to("cuda")
generator = torch.Generator("cuda").manual_seed(31)
image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=pose_image, generator=generator).images[0]
+image
```
@@ -163,6 +169,7 @@ pipeline = AutoPipelineForText2Image.from_pretrained(
image = pipeline(
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", height=768, width=512
).images[0]
+image
```
@@ -171,7 +178,7 @@ image = pipeline(
-Other models may have different default image sizes depending on the image size's in the training dataset. For example, SDXL's default image size is 1024x1024 and using lower `height` and `width` values may result in lower quality images. Make sure you check the model's API reference first!
+Other models may have different default image sizes depending on the image sizes in the training dataset. For example, SDXL's default image size is 1024x1024 and using lower `height` and `width` values may result in lower quality images. Make sure you check the model's API reference first!
@@ -189,6 +196,7 @@ pipeline = AutoPipelineForText2Image.from_pretrained(
image = pipeline(
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", guidance_scale=3.5
).images[0]
+image
```
@@ -221,16 +229,17 @@ image = pipeline(
prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
negative_prompt="ugly, deformed, disfigured, poor details, bad anatomy",
).images[0]
+image
```
-negative prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
+negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
-negative prompt = "astronaut"
+negative_prompt = "astronaut"
@@ -252,6 +261,7 @@ image = pipeline(
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
generator=generator,
).images[0]
+image
```
## Control image generation
@@ -278,14 +288,14 @@ pipeline = AutoPipelineForText2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
image = pipeline(
- prompt_emebds=prompt_embeds, # generated from Compel
+ prompt_embeds=prompt_embeds, # generated from Compel
negative_prompt_embeds=negative_prompt_embeds, # generated from Compel
).images[0]
```
### ControlNet
-As you saw in the [ControlNet](#controlnet) section, these models offer a more flexible and accurate way to generate images by incorporating an additional conditioning image input. Each ControlNet model is pretrained on a particular type of conditioning image to generate new images that resemble it. For example, if you take a ControlNet pretrained on depth maps, you can give the model a depth map as a conditioning input and it'll generate an image that preserves the spatial information in it. This is quicker and easier than specifying the depth information in a prompt. You can even combine multiple conditioning inputs with a [MultiControlNet](controlnet#multicontrolnet)!
+As you saw in the [ControlNet](#controlnet) section, these models offer a more flexible and accurate way to generate images by incorporating an additional conditioning image input. Each ControlNet model is pretrained on a particular type of conditioning image to generate new images that resemble it. For example, if you take a ControlNet model pretrained on depth maps, you can give the model a depth map as a conditioning input and it'll generate an image that preserves the spatial information in it. This is quicker and easier than specifying the depth information in a prompt. You can even combine multiple conditioning inputs with a [MultiControlNet](controlnet#multicontrolnet)!
There are many types of conditioning inputs you can use, and 🤗 Diffusers supports ControlNet for Stable Diffusion and SDXL models. Take a look at the more comprehensive [ControlNet](controlnet) guide to learn how you can use these models.
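For instance, a minimal sketch of the depth-conditioned case described above, assuming the `lllyasviel/control_v11f1p_sd15_depth` checkpoint and its sample control image:

```py
import torch
from diffusers import ControlNetModel, AutoPipelineForText2Image
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
depth_image = load_image("https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png")

pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
image = pipeline(
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=depth_image
).images[0]
image
```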
@@ -300,7 +310,7 @@ from diffusers import AutoPipelineForText2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16").to("cuda")
-pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overheard", fullgraph=True)
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
-For more tips on how to optimize your code to save memory and speed up inference, read the [Memory and speed](../optimization/fp16) and [Torch 2.0](../optimization/torch2.0) guides.
\ No newline at end of file
+For more tips on how to optimize your code to save memory and speed up inference, read the [Memory and speed](../optimization/fp16) and [Torch 2.0](../optimization/torch2.0) guides.
diff --git a/docs/source/en/using-diffusers/depth2img.md b/docs/source/en/using-diffusers/depth2img.md
index 0a6df2258235..84c613b0dade 100644
--- a/docs/source/en/using-diffusers/depth2img.md
+++ b/docs/source/en/using-diffusers/depth2img.md
@@ -20,12 +20,10 @@ Start by creating an instance of the [`StableDiffusionDepth2ImgPipeline`]:
```python
import torch
-import requests
-from PIL import Image
-
from diffusers import StableDiffusionDepth2ImgPipeline
+from diffusers.utils import load_image, make_image_grid
-pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
+pipeline = StableDiffusionDepth2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-depth",
torch_dtype=torch.float16,
use_safetensors=True,
@@ -36,22 +34,13 @@ Now pass your prompt to the pipeline. You can also pass a `negative_prompt` to p
```python
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-init_image = Image.open(requests.get(url, stream=True).raw)
+init_image = load_image(url)
prompt = "two tigers"
-n_prompt = "bad, deformed, ugly, bad anatomy"
-image = pipe(prompt=prompt, image=init_image, negative_prompt=n_prompt, strength=0.7).images[0]
-image
+negative_prompt = "bad, deformed, ugly, bad anatomy"
+image = pipeline(prompt=prompt, image=init_image, negative_prompt=negative_prompt, strength=0.7).images[0]
+make_image_grid([init_image, image], rows=1, cols=2)
```
| Input | Output |
|---------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|
| (input image) | (generated image) |
-
-Play around with the Spaces below and see if you notice a difference between generated images with and without a depth map!
-
-
diff --git a/docs/source/en/using-diffusers/img2img.md b/docs/source/en/using-diffusers/img2img.md
index 53d7c46b79c8..5caba021f39e 100644
--- a/docs/source/en/using-diffusers/img2img.md
+++ b/docs/source/en/using-diffusers/img2img.md
@@ -21,13 +21,15 @@ With 🤗 Diffusers, this is as easy as 1-2-3:
1. Load a checkpoint into the [`AutoPipelineForImage2Image`] class; this pipeline automatically handles loading the correct pipeline class based on the checkpoint:
```py
+import torch
from diffusers import AutoPipelineForImage2Image
-from diffusers.utils import load_image
+from diffusers.utils import load_image, make_image_grid
pipeline = AutoPipelineForImage2Image.from_pretrained(
- "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
```
@@ -48,7 +50,7 @@ init_image = load_image("https://huggingface.co/datasets/huggingface/documentati
```py
prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
image = pipeline(prompt, image=init_image).images[0]
-image
+make_image_grid([init_image, image], rows=1, cols=2)
```
@@ -72,27 +74,25 @@ Stable Diffusion v1.5 is a latent diffusion model initialized from an earlier ch
```py
import torch
-import requests
-from PIL import Image
-from io import BytesIO
from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-response = requests.get(url)
-init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = load_image(url)
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image).images[0]
-image
+make_image_grid([init_image, image], rows=1, cols=2)
```
@@ -112,27 +112,25 @@ SDXL is a more powerful version of the Stable Diffusion model. It uses a larger
```py
import torch
-import requests
-from PIL import Image
-from io import BytesIO
from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-sdxl-init.png"
-response = requests.get(url)
-init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = load_image(url)
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image, strength=0.5).images[0]
-image
+make_image_grid([init_image, image], rows=1, cols=2)
```
@@ -154,27 +152,25 @@ The simplest way to use Kandinsky 2.2 is:
```py
import torch
-import requests
-from PIL import Image
-from io import BytesIO
from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
- "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-response = requests.get(url)
-init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = load_image(url)
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image).images[0]
-image
+make_image_grid([init_image, image], rows=1, cols=2)
```
@@ -199,32 +195,29 @@ There are several important parameters you can configure in the pipeline that'll
- 📈 a higher `strength` value gives the model more "creativity" to generate an image that's different from the initial image; a `strength` value of 1.0 means the initial image is more or less ignored
- 📉 a lower `strength` value means the generated image is more similar to the initial image
-The `strength` and `num_inference_steps` parameter are related because `strength` determines the number of noise steps to add. For example, if the `num_inference_steps` is 50 and `strength` is 0.8, then this means adding 40 (50 * 0.8) steps of noise to the initial image and then denoising for 40 steps to get the newly generated image.
+The `strength` and `num_inference_steps` parameters are related because `strength` determines the number of noise steps to add. For example, if the `num_inference_steps` is 50 and `strength` is 0.8, then this means adding 40 (50 * 0.8) steps of noise to the initial image and then denoising for 40 steps to get the newly generated image.
```py
import torch
-import requests
-from PIL import Image
-from io import BytesIO
from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-response = requests.get(url)
-init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = load_image(url)
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-image = init_image
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image, strength=0.8).images[0]
-image
+make_image_grid([init_image, image], rows=1, cols=2)
```
@@ -250,27 +243,25 @@ You can combine `guidance_scale` with `strength` for even more precise control o
```py
import torch
-import requests
-from PIL import Image
-from io import BytesIO
from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-response = requests.get(url)
-init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = load_image(url)
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
# pass prompt and image to pipeline
image = pipeline(prompt, image=init_image, guidance_scale=8.0).images[0]
-image
+make_image_grid([init_image, image], rows=1, cols=2)
```
@@ -294,38 +285,36 @@ A negative prompt conditions the model to *not* include things in an image, and
```py
import torch
-import requests
-from PIL import Image
-from io import BytesIO
from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-response = requests.get(url)
-init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = load_image(url)
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
# pass prompt and image to pipeline
image = pipeline(prompt, negative_prompt=negative_prompt, image=init_image).images[0]
-image
+make_image_grid([init_image, image], rows=1, cols=2)
```
-negative prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
+negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
-negative prompt = "jungle"
+negative_prompt = "jungle"
@@ -342,52 +331,54 @@ Start by generating an image with the text-to-image pipeline:
```py
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
import torch
+from diffusers.utils import make_image_grid
pipeline = AutoPipelineForText2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
-image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0]
+text2image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0]
+text2image
```
Now you can pass this generated image to the image-to-image pipeline:
```py
pipeline = AutoPipelineForImage2Image.from_pretrained(
- "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+ "kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
-image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=image).images[0]
-image
+image2image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=text2image).images[0]
+make_image_grid([text2image, image2image], rows=1, cols=2)
```
### Image-to-image-to-image
-You can also chain multiple image-to-image pipelines together to create more interesting images. This can be useful for iteratively performing style transfer on an image, generate short GIFs, restore color to an image, or restore missing areas of an image.
+You can also chain multiple image-to-image pipelines together to create more interesting images. This can be useful for iteratively performing style transfer on an image, generating short GIFs, restoring color to an image, or restoring missing areas of an image.
Start by generating an image:
```py
import torch
-import requests
-from PIL import Image
-from io import BytesIO
from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-response = requests.get(url)
-init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = load_image(url)
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
@@ -404,10 +395,11 @@ It is important to specify `output_type="latent"` in the pipeline to keep all th
Pass the latent output from this pipeline to the next pipeline to generate an image in a [comic book art style](https://huggingface.co/ogkalu/Comic-Diffusion):
```py
-pipelne = AutoPipelineForImage2Image.from_pretrained(
- "ogkalu/Comic-Diffusion", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "ogkalu/Comic-Diffusion", torch_dtype=torch.float16
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# need to include the token "charliebo artstyle" in the prompt to use this checkpoint
@@ -418,14 +410,15 @@ Repeat one more time to generate the final image in a [pixel art style](https://
```py
pipeline = AutoPipelineForImage2Image.from_pretrained(
- "kohbanye/pixel-art-style", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+ "kohbanye/pixel-art-style", torch_dtype=torch.float16
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# need to include the token "pixelartstyle" in the prompt to use this checkpoint
image = pipeline("Astronaut in a jungle, pixelartstyle", image=image).images[0]
-image
+make_image_grid([init_image, image], rows=1, cols=2)
```
### Image-to-upscaler-to-super-resolution
@@ -436,21 +429,19 @@ Start with an image-to-image pipeline:
```py
import torch
-import requests
-from PIL import Image
-from io import BytesIO
from diffusers import AutoPipelineForImage2Image
+from diffusers.utils import make_image_grid, load_image
pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-response = requests.get(url)
-init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = load_image(url)
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
@@ -467,7 +458,9 @@ It is important to specify `output_type="latent"` in the pipeline to keep all th
Chain it to an upscaler pipeline to increase the image resolution:
```py
-upscaler = AutoPipelineForImage2Image.from_pretrained(
+from diffusers import StableDiffusionLatentUpscalePipeline
+
+upscaler = StableDiffusionLatentUpscalePipeline.from_pretrained(
"stabilityai/sd-x2-latent-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
upscaler.enable_model_cpu_offload()
@@ -479,14 +472,16 @@ image_2 = upscaler(prompt, image=image_1, output_type="latent").images[0]
Finally, chain it to a super-resolution pipeline to further enhance the resolution:
```py
-super_res = AutoPipelineForImage2Image.from_pretrained(
+from diffusers import StableDiffusionUpscalePipeline
+
+super_res = StableDiffusionUpscalePipeline.from_pretrained(
"stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
super_res.enable_model_cpu_offload()
super_res.enable_xformers_memory_efficient_attention()
-image_3 = upscaler(prompt, image=image_2).images[0]
-image_3
+image_3 = super_res(prompt, image=image_2).images[0]
+make_image_grid([init_image, image_3.resize((512, 512))], rows=1, cols=2)
```
## Control image generation
@@ -504,13 +499,14 @@ from diffusers import AutoPipelineForImage2Image
import torch
pipeline = AutoPipelineForImage2Image.from_pretrained(
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
-image = pipeline(prompt_emebds=prompt_embeds, # generated from Compel
- negative_prompt_embeds, # generated from Compel
+image = pipeline(prompt_embeds=prompt_embeds, # generated from Compel
+ negative_prompt_embeds=negative_prompt_embeds, # generated from Compel
image=init_image,
).images[0]
```
@@ -522,19 +518,20 @@ ControlNets provide a more flexible and accurate way to control image generation
For example, let's condition an image with a depth map to keep the spatial information in the image.
```py
+from diffusers.utils import load_image, make_image_grid
+
# prepare image
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
-response = requests.get(url)
-init_image = Image.open(BytesIO(response.content)).convert("RGB")
+init_image = load_image(url)
init_image = init_image.resize((958, 960)) # resize to depth image dimensions
depth_image = load_image("https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/images/control.png")
+make_image_grid([init_image, depth_image], rows=1, cols=2)
```
Load a ControlNet model conditioned on depth maps and the [`AutoPipelineForImage2Image`]:
```py
from diffusers import ControlNetModel, AutoPipelineForImage2Image
-from diffusers.utils import load_image
import torch
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11f1p_sd15_depth", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
@@ -542,6 +539,7 @@ pipeline = AutoPipelineForImage2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
```
@@ -549,8 +547,8 @@ Now generate a new image conditioned on the depth map, initial image, and prompt
```py
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-image = pipeline(prompt, image=init_image, control_image=depth_image).images[0]
-image
+image_control_net = pipeline(prompt, image=init_image, control_image=depth_image).images[0]
+make_image_grid([init_image, depth_image, image_control_net], rows=1, cols=3)
```
@@ -575,13 +573,14 @@ pipeline = AutoPipelineForImage2Image.from_pretrained(
"nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
prompt = "elden ring style astronaut in a jungle" # include the token "elden ring style" in the prompt
negative_prompt = "ugly, deformed, disfigured, poor details, bad anatomy"
-image = pipeline(prompt, negative_prompt=negative_prompt, image=init_image, strength=0.45, guidance_scale=10.5).images[0]
-image
+image_elden_ring = pipeline(prompt, negative_prompt=negative_prompt, image=image_control_net, strength=0.45, guidance_scale=10.5).images[0]
+make_image_grid([init_image, depth_image, image_control_net, image_elden_ring], rows=2, cols=2)
```
@@ -597,10 +596,10 @@ Running diffusion models is computationally expensive and intensive, but with a
+ pipeline.enable_xformers_memory_efficient_attention()
```
-With [`torch.compile`](../optimization/torch2.0#torch.compile), you can boost your inference speed even more by wrapping your UNet with it:
+With [`torch.compile`](../optimization/torch2.0#torchcompile), you can boost your inference speed even more by wrapping your UNet with it:
```py
-pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
To learn more, take a look at the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
diff --git a/docs/source/en/using-diffusers/inpaint.md b/docs/source/en/using-diffusers/inpaint.md
index 42bfb8984d9e..3d03d4e0e4d0 100644
--- a/docs/source/en/using-diffusers/inpaint.md
+++ b/docs/source/en/using-diffusers/inpaint.md
@@ -23,12 +23,13 @@ With 🤗 Diffusers, here is how you can do inpainting:
```py
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image
+from diffusers.utils import load_image, make_image_grid
pipeline = AutoPipelineForInpainting.from_pretrained(
"kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
```
@@ -41,8 +42,8 @@ You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu
2. Load the base and mask images:
```py
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
```
3. Create a prompt to inpaint the image with and pass it to the pipeline with the base and mask images:
@@ -51,6 +52,7 @@ mask_image = load_image("https://huggingface.co/datasets/huggingface/documentati
prompt = "a black cat with glowing eyes, cute, adorable, disney, pixar, highly detailed, 8k"
negative_prompt = "bad anatomy, deformed, ugly, disfigured"
image = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask_image).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
@@ -58,6 +60,10 @@ image = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=init_imag
base image
+
+
+
mask image
+
generated image
@@ -79,7 +85,7 @@ Upload a base image to inpaint on and use the sketch tool to draw a mask. Once y
## Popular models
-[Stable Diffusion Inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting), [Stable Diffusion XL (SDXL) Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint) are among the most popular models for inpainting. SDXL typically produces higher resolution images than Stable Diffusion v1.5, and Kandinsky 2.2 is also capable of generating high-quality images.
+[Stable Diffusion Inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting), [Stable Diffusion XL (SDXL) Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1), and [Kandinsky 2.2 Inpainting](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint) are among the most popular models for inpainting. SDXL typically produces higher resolution images than Stable Diffusion v1.5, and Kandinsky 2.2 is also capable of generating high-quality images.
### Stable Diffusion Inpainting
@@ -88,21 +94,23 @@ Stable Diffusion Inpainting is a latent diffusion model finetuned on 512x512 ima
```py
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image
+from diffusers.utils import load_image, make_image_grid
pipeline = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
generator = torch.Generator("cuda").manual_seed(92)
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
### Stable Diffusion XL (SDXL) Inpainting
@@ -112,21 +120,23 @@ SDXL is a larger and more powerful version of Stable Diffusion v1.5. This model
```py
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image
+from diffusers.utils import load_image, make_image_grid
pipeline = AutoPipelineForInpainting.from_pretrained(
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
generator = torch.Generator("cuda").manual_seed(92)
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
### Kandinsky 2.2 Inpainting
@@ -136,21 +146,23 @@ The Kandinsky model family is similar to SDXL because it uses two models as well
```py
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image
+from diffusers.utils import load_image, make_image_grid
pipeline = AutoPipelineForInpainting.from_pretrained(
"kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
generator = torch.Generator("cuda").manual_seed(92)
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
@@ -186,20 +198,22 @@ Image features - like quality and "creativity" - are dependent on pipeline param
```py
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image
+from diffusers.utils import load_image, make_image_grid
pipeline = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.6).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
@@ -229,20 +243,22 @@ You can use `strength` and `guidance_scale` together for more control over how e
```py
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image
+from diffusers.utils import load_image, make_image_grid
pipeline = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, guidance_scale=2.5).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
@@ -267,22 +283,23 @@ A negative prompt assumes the opposite role of a prompt; it guides the model awa
```py
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image
+from diffusers.utils import load_image, make_image_grid
pipeline = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
negative_prompt = "bad architecture, unstable, poor details, blurry"
image = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask_image).images[0]
-image
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
@@ -302,7 +319,7 @@ import numpy as np
import torch
from diffusers import AutoPipelineForInpainting
-from diffusers.utils import load_image
+from diffusers.utils import load_image, make_image_grid
device = "cuda"
pipeline = AutoPipelineForInpainting.from_pretrained(
@@ -334,6 +351,7 @@ mask_image_arr[mask_image_arr >= 0.5] = 1
unmasked_unchanged_image_arr = (1 - mask_image_arr) * init_image + mask_image_arr * repainted_image
unmasked_unchanged_image = PIL.Image.fromarray(unmasked_unchanged_image_arr.round().astype("uint8"))
unmasked_unchanged_image.save("force_unmasked_unchanged.png")
+make_image_grid([init_image, mask_image, repainted_image, unmasked_unchanged_image], rows=2, cols=2)
```
## Chained inpainting pipelines
@@ -349,35 +367,37 @@ Start with the text-to-image pipeline to create a castle:
```py
import torch
from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting
-from diffusers.utils import load_image
+from diffusers.utils import load_image, make_image_grid
pipeline = AutoPipelineForText2Image.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
-image = pipeline("concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k").images[0]
+text2image = pipeline("concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k").images[0]
```
Load the mask image of the output from above:
```py
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_text-chain-mask.png").convert("RGB")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_text-chain-mask.png")
```
And let's inpaint the masked area with a waterfall:
```py
pipeline = AutoPipelineForInpainting.from_pretrained(
- "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16, variant="fp16"
+ "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
prompt = "digital painting of a fantasy waterfall, cloudy"
-image = pipeline(prompt=prompt, image=image, mask_image=mask_image).images[0]
-image
+image = pipeline(prompt=prompt, image=text2image, mask_image=mask_image).images[0]
+make_image_grid([text2image, mask_image, image], rows=1, cols=3)
```
@@ -391,7 +411,6 @@ image
-
### Inpaint-to-image-to-image
You can also chain an inpainting pipeline before another pipeline like image-to-image or an upscaler to improve the quality.
@@ -401,23 +420,24 @@ Begin by inpainting an image:
```py
import torch
from diffusers import AutoPipelineForInpainting, AutoPipelineForImage2Image
-from diffusers.utils import load_image
+from diffusers.utils import load_image, make_image_grid
pipeline = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
-image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+image_inpainting = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
# resize image to 1024x1024 for SDXL
-image = image.resize((1024, 1024))
+image_inpainting = image_inpainting.resize((1024, 1024))
```
Now let's pass the image to another inpainting pipeline with SDXL's refiner model to enhance the image details and quality:
@@ -427,9 +447,10 @@ pipeline = AutoPipelineForInpainting.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
-image = pipeline(prompt=prompt, image=image, mask_image=mask_image, output_type="latent").images[0]
+image = pipeline(prompt=prompt, image=image_inpainting, mask_image=mask_image, output_type="latent").images[0]
```
@@ -442,9 +463,11 @@ Finally, you can pass this image to an image-to-image pipeline to put the finish
```py
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline)
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
image = pipeline(prompt=prompt, image=image).images[0]
+make_image_grid([init_image, mask_image, image_inpainting, image], rows=2, cols=2)
```
@@ -477,18 +500,21 @@ Once you've generated the embeddings, pass them to the `prompt_embeds` (and `neg
```py
import torch
from diffusers import AutoPipelineForInpainting
+from diffusers.utils import make_image_grid
pipeline = AutoPipelineForInpainting.from_pretrained(
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
-image = pipeline(prompt_emebds=prompt_embeds, # generated from Compel
- negative_prompt_embeds, # generated from Compel
+image = pipeline(prompt_embeds=prompt_embeds, # generated from Compel
+ negative_prompt_embeds=negative_prompt_embeds, # generated from Compel
image=init_image,
mask_image=mask_image
).images[0]
+make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```
### ControlNet
@@ -501,7 +527,7 @@ For example, let's condition an image with a ControlNet pretrained on inpaint im
import torch
import numpy as np
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
-from diffusers.utils import load_image
+from diffusers.utils import load_image, make_image_grid
# load ControlNet
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16, variant="fp16")
@@ -511,11 +537,12 @@ pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
# load base and mask image
-init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
-mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
+mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
# prepare control image
def make_inpaint_condition(init_image, mask_image):
@@ -536,7 +563,7 @@ Now generate an image from the base, mask and control images. You'll notice feat
```py
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, control_image=control_image).images[0]
-image
+make_image_grid([init_image, mask_image, PIL.Image.fromarray(np.uint8(control_image[0][0])).convert('RGB'), image], rows=2, cols=2)
```
You can take this a step further and chain it with an image-to-image pipeline to apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion):
@@ -548,13 +575,14 @@ pipeline = AutoPipelineForImage2Image.from_pretrained(
"nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_model_cpu_offload()
+# remove following line if xFormers is not installed or you have PyTorch 2.0 or higher installed
pipeline.enable_xformers_memory_efficient_attention()
prompt = "elden ring style castle" # include the token "elden ring style" in the prompt
negative_prompt = "bad architecture, deformed, disfigured, poor details"
-image = pipeline(prompt, negative_prompt=negative_prompt, image=image).images[0]
-image
+image_elden_ring = pipeline(prompt, negative_prompt=negative_prompt, image=image).images[0]
+make_image_grid([init_image, mask_image, image, image_elden_ring], rows=2, cols=2)
```
@@ -576,17 +604,17 @@ image
It can be difficult and slow to run diffusion models if you're resource constrained, but it doesn't have to be with a few optimization tricks. One of the biggest (and easiest) optimizations you can enable is switching to memory-efficient attention. If you're using PyTorch 2.0, [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) is automatically enabled and you don't need to do anything else. For non-PyTorch 2.0 users, you can install and use [xFormers](../optimization/xformers)'s implementation of memory-efficient attention. Both options reduce memory usage and accelerate inference.
-You can also offload the model to the GPU to save even more memory:
+You can also offload the model to the CPU to save even more memory:
```diff
+ pipeline.enable_xformers_memory_efficient_attention()
+ pipeline.enable_model_cpu_offload()
```
-To speed-up your inference code even more, use [`torch_compile`](../optimization/torch2.0#torch.compile). You should wrap `torch.compile` around the most intensive component in the pipeline which is typically the UNet:
+To speed up your inference code even more, use [`torch_compile`](../optimization/torch2.0#torchcompile). You should wrap `torch.compile` around the most intensive component in the pipeline, which is typically the UNet:
```py
-pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
-Learn more in the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
\ No newline at end of file
+Learn more in the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
diff --git a/docs/source/en/using-diffusers/lcm.md b/docs/source/en/using-diffusers/lcm.md
new file mode 100644
index 000000000000..39bc2426a92b
--- /dev/null
+++ b/docs/source/en/using-diffusers/lcm.md
@@ -0,0 +1,154 @@
+
+
+# Performing inference with LCM
+
+Latent Consistency Models (LCM) enable high-quality image generation in typically 2-4 steps, making it possible to use diffusion models in almost real-time settings.
+
+From the [official website](https://latent-consistency-models.github.io/):
+
+> LCMs can be distilled from any pre-trained Stable Diffusion (SD) in only 4,000 training steps (~32 A100 GPU Hours) for generating high quality 768 x 768 resolution images in 2~4 steps or even one step, significantly accelerating text-to-image generation. We employ LCM to distill the Dreamshaper-V7 version of SD in just 4,000 training iterations.
+
+For a more technical overview of LCMs, refer to [the paper](https://huggingface.co/papers/2310.04378).
+
+This guide shows how to perform inference with LCMs for text-to-image and image-to-image generation tasks. It will also cover performing inference with LoRA checkpoints.
+
+## Text-to-image
+
+You'll use the [`StableDiffusionXLPipeline`] here, swapping in a distilled `unet`. The UNet was distilled from the SDXL UNet using the framework introduced in LCM. The other important component is the scheduler: [`LCMScheduler`]. Together, the distilled UNet and the scheduler enable a fast inference workflow that overcomes the slow iterative nature of diffusion models.
+
+```python
+from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
+import torch
+
+unet = UNet2DConditionModel.from_pretrained(
+ "latent-consistency/lcm-sdxl",
+ torch_dtype=torch.float16,
+ variant="fp16",
+)
+pipe = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16
+).to("cuda")
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+prompt = "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
+
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt=prompt, num_inference_steps=4, generator=generator, guidance_scale=8.0
+).images[0]
+```
+
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_intro.png)
+
+Notice that we use only 4 steps for generation, which is far fewer than what's typically required for standard SDXL.
+
+Some details to keep in mind:
+
+* To perform classifier-free guidance, the batch size is usually doubled inside the pipeline. LCM, however, applies guidance using guidance embeddings, so the batch size does not have to be doubled in this case. This leads to faster inference, with the drawback that negative prompts don't have any effect on the denoising process.
+* The UNet was trained using the [3., 13.] guidance scale range, so that is the ideal range for `guidance_scale`. However, disabling `guidance_scale` with a value of 1.0 is also effective in most cases, as shown in the sketch below.
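+
+For instance, here is a minimal sketch (reusing the `pipe` and `prompt` defined above) of disabling guidance by passing `guidance_scale=1.0`:
+
+```python
+# guidance_scale=1.0 effectively disables guidance for the distilled UNet
+image = pipe(
+    prompt=prompt,
+    num_inference_steps=4,
+    generator=torch.manual_seed(0),
+    guidance_scale=1.0,
+).images[0]
+```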
+
+## Image-to-image
+
+The findings above apply to image-to-image tasks too. Let's look at how we can perform image-to-image generation with LCMs:
+
+```python
+from diffusers import AutoPipelineForImage2Image, UNet2DConditionModel, LCMScheduler
+from diffusers.utils import load_image
+import torch
+
+unet = UNet2DConditionModel.from_pretrained(
+ "latent-consistency/lcm-sdxl",
+ torch_dtype=torch.float16,
+ variant="fp16",
+)
+pipe = AutoPipelineForImage2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16
+).to("cuda")
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+prompt = "High altitude snowy mountains"
+image = load_image(
+ "https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/snowy_mountains.jpeg"
+)
+
+generator = torch.manual_seed(0)
+image = pipe(
+ prompt=prompt,
+ image=image,
+ num_inference_steps=4,
+ generator=generator,
+ guidance_scale=8.0,
+).images[0]
+```
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_i2i.png)
+
+## LoRA
+
+The LCM framework can also be generalized to work with [LoRA](../training/lora.md). This effectively eliminates the need for expensive fine-tuning runs, because LoRA training only updates a small number of parameters compared to full fine-tuning. During inference, the [`LCMScheduler`] is particularly advantageous because it enables very few-step inference without compromising quality.
+
+We recommend disabling `guidance_scale` by setting it to 0. The model is trained to follow prompts accurately
+even without guidance. You can, however, still use guidance scale, in which case we recommend
+using values between 1.0 and 2.0.
+
+### Text-to-image
+
+```python
+from diffusers import DiffusionPipeline, LCMScheduler
+import torch
+
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
+
+pipe = DiffusionPipeline.from_pretrained(model_id, variant="fp16", torch_dtype=torch.float16).to("cuda")
+
+pipe.load_lora_weights(lcm_lora_id)
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+prompt = "close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm summilux"
+image = pipe(
+ prompt=prompt,
+ num_inference_steps=4,
+ guidance_scale=0, # set guidance scale to 0 to disable it
+).images[0]
+```
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lora_lcm.png)
+
+### Image-to-image
+
+Extending LCM LoRA to image-to-image is possible:
+
+```python
+from diffusers import StableDiffusionXLImg2ImgPipeline, LCMScheduler
+from diffusers.utils import load_image
+import torch
+
+model_id = "stabilityai/stable-diffusion-xl-base-1.0"
+lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
+
+pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(model_id, variant="fp16", torch_dtype=torch.float16).to("cuda")
+
+pipe.load_lora_weights(lcm_lora_id)
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+prompt = "close-up photography of old man standing in the rain at night, in a street lit by lamps, leica 35mm summilux"
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lora_lcm.png")
+
+image = pipe(
+ prompt=prompt,
+ image=image,
+ num_inference_steps=4,
+ guidance_scale=0, # set guidance scale to 0 to disable it
+).images[0]
+```
+![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lcm/lcm_lora_i2i.png)
diff --git a/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md b/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md
index 9cf82907180c..6f75ba2c3999 100644
--- a/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md
+++ b/docs/source/en/using-diffusers/stable_diffusion_jax_how_to.md
@@ -38,25 +38,20 @@ device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")
assert (
- "TPU" in device_type,
- "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
+ "TPU" in device_type,
+ "Available device is not a TPU, please select TPU from Runtime > Change runtime type > Hardware accelerator"
)
-"Found 8 JAX devices of type Cloud TPU."
+# Found 8 JAX devices of type Cloud TPU.
```
Great, now you can import the rest of the dependencies you'll need:
```python
-import numpy as np
import jax.numpy as jnp
-
-from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
-from PIL import Image
-from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline
```
@@ -90,7 +85,7 @@ prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, por
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
-"(8, 77)"
+# (8, 77)
```
Model parameters and inputs have to be replicated across the 8 parallel devices. The parameters dictionary is replicated with [`flax.jax_utils.replicate`](https://flax.readthedocs.io/en/latest/api_reference/flax.jax_utils.html#flax.jax_utils.replicate) which traverses the dictionary and changes the shape of the weights so they are repeated 8 times. Arrays are replicated using `shard`.
@@ -102,7 +97,7 @@ p_params = replicate(params)
# arrays
prompt_ids = shard(prompt_ids)
prompt_ids.shape
-"(8, 1, 77)"
+# (8, 1, 77)
```
This shape means each one of the 8 devices receives as an input a `jnp` array with shape `(1, 77)`, where `1` is the batch size per device. On TPUs with sufficient memory, you could have a batch size larger than `1` if you want to generate multiple images (per chip) at once.
@@ -127,7 +122,7 @@ To take advantage of JAX's optimized speed on a TPU, pass `jit=True` to the pipe
-You need to ensure all your inputs have the same shape in subsequent calls, other JAX will need to recompile the code which is slower.
+You need to ensure all your inputs have the same shape in subsequent calls, otherwise JAX will need to recompile the code, which is slower.
@@ -137,18 +132,18 @@ The first inference run takes more time because it needs to compile the code, bu
%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
-"CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s"
-"Wall time: 1min 29s"
+# CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
+# Wall time: 1min 29s
```
The returned array has shape `(8, 1, 512, 512, 3)` which should be reshaped to remove the second dimension and get 8 images of `512 × 512 × 3`. Then you can use the [`~utils.numpy_to_pil`] function to convert the arrays into images.
```python
-from diffusers import make_image_grid
+from diffusers.utils import make_image_grid
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
-make_image_grid(images, 2, 4)
+make_image_grid(images, rows=2, cols=4)
```
![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_38_output_0.jpeg)
@@ -181,7 +176,6 @@ make_image_grid(images, 2, 4)
![img](https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/stable_diffusion_jax_how_to_cell_43_output_0.jpeg)
-
## How does parallelization work?
The Flax pipeline in 🤗 Diffusers automatically compiles the model and runs it in parallel on all available devices. Let's take a closer look at how that process works.
@@ -202,7 +196,7 @@ p_generate = pmap(pipeline._generate)
After calling `pmap`, the prepared function `p_generate` will:
1. Make a copy of the underlying function, `pipeline._generate`, on each device.
-2. Send each device a different portion of the input arguments (this is why its necessary to call the *shard* function). In this case, `prompt_ids` has shape `(8, 1, 77, 768)` so the array is split into 8 and each copy of `_generate` receives an input with shape `(1, 77, 768)`.
+2. Send each device a different portion of the input arguments (this is why it's necessary to call the *shard* function). In this case, `prompt_ids` has shape `(8, 1, 77, 768)` so the array is split into 8 and each copy of `_generate` receives an input with shape `(1, 77, 768)`.
The most important thing to pay attention to here is the batch size (1 in this example), and the input dimensions that make sense for your code. You don't have to change anything else to make the code work in parallel.
@@ -212,13 +206,14 @@ The first time you call the pipeline takes more time, but the calls afterward ar
%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
-"CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s"
-"Wall time: 1min 15s"
+
+# CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
+# Wall time: 1min 15s
```
Check your image dimensions to see if they're correct:
```python
images.shape
-"(8, 1, 512, 512, 3)"
-```
\ No newline at end of file
+# (8, 1, 512, 512, 3)
+```
diff --git a/docs/source/en/using-diffusers/unconditional_image_generation.md b/docs/source/en/using-diffusers/unconditional_image_generation.md
index 3893f7cce276..c055bc75c5a4 100644
--- a/docs/source/en/using-diffusers/unconditional_image_generation.md
+++ b/docs/source/en/using-diffusers/unconditional_image_generation.md
@@ -23,16 +23,16 @@ You can use any of the 🧨 Diffusers [checkpoints](https://huggingface.co/model
-💡 Want to train your own unconditional image generation model? Take a look at the training [guide](training/unconditional_training) to learn how to generate your own images.
+💡 Want to train your own unconditional image generation model? Take a look at the training [guide](../training/unconditional_training) to learn how to generate your own images.
In this guide, you'll use [`DiffusionPipeline`] for unconditional image generation with [DDPM](https://arxiv.org/abs/2006.11239):
```python
->>> from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline
->>> generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128", use_safetensors=True)
+generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128", use_safetensors=True)
```
The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
@@ -40,13 +40,14 @@ Because the model consists of roughly 1.4 billion parameters, we strongly recomm
You can move the generator object to a GPU, just like you would in PyTorch:
```python
->>> generator.to("cuda")
+generator.to("cuda")
```
Now you can use the `generator` to generate an image:
```python
->>> image = generator().images[0]
+image = generator().images[0]
+image
```
The output is by default wrapped into a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object.
@@ -54,7 +55,7 @@ The output is by default wrapped into a [`PIL.Image`](https://pillow.readthedocs
You can save the image by calling:
```python
->>> image.save("generated_image.png")
+image.save("generated_image.png")
```
Try out the Spaces below, and feel free to play around with the inference steps parameter to see how it affects the image quality!
@@ -65,5 +66,3 @@ Try out the Spaces below, and feel free to play around with the inference steps
width="850"
height="500"
>
-
-
diff --git a/examples/consistency_distillation/README.md b/examples/consistency_distillation/README.md
new file mode 100644
index 000000000000..c584736dfe82
--- /dev/null
+++ b/examples/consistency_distillation/README.md
@@ -0,0 +1,104 @@
+# Latent Consistency Distillation Example:
+
+[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method for distilling a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill stable-diffusion-v1.5 for inference with fewer timesteps.
+
+## Full model distillation
+
+### Running locally with PyTorch
+
+#### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, specifying torch compile mode as True can yield dramatic speedups.
+
+
+#### Example with LAION-A6+ dataset
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+PROGRAM="train_lcm_distill_sd_wds.py \
+ --pretrained_teacher_model=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision=fp16 \
+ --resolution=512 \
+ --learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
+ --max_train_steps=1000 \
+ --max_train_samples=4000000 \
+ --dataloader_num_workers=8 \
+ --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
+ --validation_steps=200 \
+ --checkpointing_steps=200 --checkpoints_total_limit=10 \
+ --train_batch_size=12 \
+ --gradient_checkpointing --enable_xformers_memory_efficient_attention \
+ --gradient_accumulation_steps=1 \
+ --use_8bit_adam \
+ --resume_from_checkpoint=latest \
+ --report_to=wandb \
+ --seed=453645634 \
+ --push_to_hub \
+```
+
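+Once training finishes, a minimal inference sketch with the distilled UNet looks like the following (the `path/to/distilled/unet` location is illustrative; point it at the UNet saved by your run, for example inside `$OUTPUT_DIR`):
+
+```python
+import torch
+from diffusers import DiffusionPipeline, LCMScheduler, UNet2DConditionModel
+
+# load the distilled UNet produced by the distillation run (placeholder path)
+unet = UNet2DConditionModel.from_pretrained("path/to/distilled/unet", torch_dtype=torch.float16)
+pipe = DiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16
+).to("cuda")
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+# LCM needs only a few steps; guidance_scale=1.0 disables classifier-free guidance
+image = pipe("a photo of an astronaut riding a horse", num_inference_steps=4, guidance_scale=1.0).images[0]
+```
+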
+## LCM-LoRA
+
+Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any Stable Diffusion model.
+
+### Example with LAION-A6+ dataset
+
+```bash
+export MODEL_DIR="runwayml/stable-diffusion-v1-5"
+PROGRAM="train_lcm_distill_lora_sd_wds.py \
+ --pretrained_teacher_model=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision=fp16 \
+ --resolution=512 \
+ --lora_rank=64 \
+ --learning_rate=1e-6 --loss_type="huber" --adam_weight_decay=0.0 \
+ --max_train_steps=1000 \
+ --max_train_samples=4000000 \
+ --dataloader_num_workers=8 \
+ --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
+ --validation_steps=200 \
+ --checkpointing_steps=200 --checkpoints_total_limit=10 \
+ --train_batch_size=12 \
+ --gradient_checkpointing --enable_xformers_memory_efficient_attention \
+ --gradient_accumulation_steps=1 \
+ --use_8bit_adam \
+ --resume_from_checkpoint=latest \
+ --report_to=wandb \
+ --seed=453645634 \
+ --push_to_hub \
+```
\ No newline at end of file
diff --git a/examples/consistency_distillation/README_sdxl.md b/examples/consistency_distillation/README_sdxl.md
new file mode 100644
index 000000000000..00577f9fa2b8
--- /dev/null
+++ b/examples/consistency_distillation/README_sdxl.md
@@ -0,0 +1,106 @@
+# Latent Consistency Distillation Example:
+
+[Latent Consistency Models (LCMs)](https://arxiv.org/abs/2310.04378) is a method for distilling a latent diffusion model to enable swift inference with minimal steps. This example demonstrates how to use latent consistency distillation to distill SDXL for inference with fewer timesteps.
+
+## Full model distillation
+
+### Running locally with PyTorch
+
+#### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install -e .
+```
+
+Then cd into the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Or for a default accelerate configuration without answering questions about your environment
+
+```bash
+accelerate config default
+```
+
+Or if your environment doesn't support an interactive shell (e.g., a notebook)
+
+```python
+from accelerate.utils import write_basic_config
+write_basic_config()
+```
+
+When running `accelerate config`, specifying torch compile mode as True can yield dramatic speedups.
+
+
+#### Example with LAION-A6+ dataset
+
+```bash
+export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
+PROGRAM="train_lcm_distill_sdxl_wds.py \
+ --pretrained_teacher_model=$MODEL_DIR \
+ --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision=fp16 \
+ --resolution=1024 \
+ --learning_rate=1e-6 --loss_type="huber" --use_fix_crop_and_size --ema_decay=0.95 --adam_weight_decay=0.0 \
+ --max_train_steps=1000 \
+ --max_train_samples=4000000 \
+ --dataloader_num_workers=8 \
+ --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
+ --validation_steps=200 \
+ --checkpointing_steps=200 --checkpoints_total_limit=10 \
+ --train_batch_size=12 \
+ --gradient_checkpointing --enable_xformers_memory_efficient_attention \
+ --gradient_accumulation_steps=1 \
+ --use_8bit_adam \
+ --resume_from_checkpoint=latest \
+ --report_to=wandb \
+ --seed=453645634 \
+ --push_to_hub \
+```
+
+## LCM-LoRA
+
+Instead of fine-tuning the full model, we can also just train a LoRA that can be injected into any SDXL model.
+
+### Example with LAION-A6+ dataset
+
+```bash
+export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
+PROGRAM="train_lcm_distill_lora_sdxl_wds.py \
+ --pretrained_teacher_model=$MODEL_DIR \
+ --pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
+ --output_dir=$OUTPUT_DIR \
+ --mixed_precision=fp16 \
+ --resolution=1024 \
+ --lora_rank=64 \
+ --learning_rate=1e-6 --loss_type="huber" --use_fix_crop_and_size --adam_weight_decay=0.0 \
+ --max_train_steps=1000 \
+ --max_train_samples=4000000 \
+ --dataloader_num_workers=8 \
+ --train_shards_path_or_url='pipe:aws s3 cp s3://muse-datasets/laion-aesthetic6plus-min512-data/{00000..01210}.tar -' \
+ --validation_steps=200 \
+ --checkpointing_steps=200 --checkpoints_total_limit=10 \
+ --train_batch_size=12 \
+ --gradient_checkpointing --enable_xformers_memory_efficient_attention \
+ --gradient_accumulation_steps=1 \
+ --use_8bit_adam \
+ --resume_from_checkpoint=latest \
+ --report_to=wandb \
+ --seed=453645634 \
+ --push_to_hub \
+```
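+
+After training, a minimal sketch of loading the resulting LCM-LoRA for inference (the `path/to/lcm/lora` location is illustrative; use your `$OUTPUT_DIR` or the Hub repo it was pushed to):
+
+```python
+import torch
+from diffusers import DiffusionPipeline, LCMScheduler
+
+pipe = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", variant="fp16", torch_dtype=torch.float16
+).to("cuda")
+
+# load the LCM-LoRA produced by the training run (placeholder path)
+pipe.load_lora_weights("path/to/lcm/lora")
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+image = pipe("Astronaut in a jungle, cold color palette, 8k", num_inference_steps=4, guidance_scale=0).images[0]
+```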
\ No newline at end of file
diff --git a/examples/consistency_distillation/requirements.txt b/examples/consistency_distillation/requirements.txt
new file mode 100644
index 000000000000..09fb84270a8a
--- /dev/null
+++ b/examples/consistency_distillation/requirements.txt
@@ -0,0 +1,7 @@
+accelerate>=0.16.0
+torchvision
+transformers>=4.25.1
+ftfy
+tensorboard
+Jinja2
+webdataset
\ No newline at end of file
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
new file mode 100644
index 000000000000..6fa8d2c57832
--- /dev/null
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -0,0 +1,1321 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import functools
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+from typing import List, Union
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision.transforms.functional as TF
+import transformers
+import webdataset as wds
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from braceexpand import braceexpand
+from huggingface_hub import create_repo
+from packaging import version
+from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
+from torch.utils.data import default_collate
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
+from webdataset.tariterators import (
+ base_plus_ext,
+ tar_file_expander,
+ url_opener,
+ valid_sample,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ LCMScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+MAX_SEQ_LENGTH = 77
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.18.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"):
+ kohya_ss_state_dict = {}
+ for peft_key, weight in get_peft_model_state_dict(module, adapter_name=adapter_name).items():
+ kohya_key = peft_key.replace("base_model.model", prefix)
+ kohya_key = kohya_key.replace("lora_A", "lora_down")
+ kohya_key = kohya_key.replace("lora_B", "lora_up")
+ kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
+ kohya_ss_state_dict[kohya_key] = weight.to(dtype)
+
+ # Set alpha parameter
+ if "lora_down" in kohya_key:
+ alpha_key = f'{kohya_key.split(".")[0]}.alpha'
+ kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
+
+ return kohya_ss_state_dict
+
+
+def filter_keys(key_set):
+ def _f(dictionary):
+ return {k: v for k, v in dictionary.items() if k in key_set}
+
+ return _f
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+
+ :param keys: function that splits the key into key and extension (base_plus_ext)
+ :param lcase: convert suffixes to lower case (Default value = True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+ # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+ # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
+ # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+class WebdatasetFilter:
+ def __init__(self, min_size=1024, max_pwatermark=0.5):
+ self.min_size = min_size
+ self.max_pwatermark = max_pwatermark
+
+ def __call__(self, x):
+ try:
+ if "json" in x:
+ x_json = json.loads(x["json"])
+ filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
+ "original_height", 0
+ ) >= self.min_size
+ filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+ return filter_size and filter_watermark
+ else:
+ return False
+ except Exception:
+ return False
+
+
+class Text2ImageDataset:
+ def __init__(
+ self,
+ train_shards_path_or_url: Union[str, List[str]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers: int,
+ resolution: int = 512,
+ shuffle_buffer_size: int = 1000,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ ):
+ if not isinstance(train_shards_path_or_url, str):
+ train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+ # flatten list using itertools
+ train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
+
+ def transform(example):
+ # resize image
+ image = example["image"]
+ image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+
+ # get crop coordinates and crop image
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = TF.crop(image, c_top, c_left, resolution, resolution)
+ image = TF.to_tensor(image)
+ image = TF.normalize(image, [0.5], [0.5])
+
+ example["image"] = image
+ return example
+
+ processing_pipeline = [
+ wds.decode("pil", handler=wds.ignore_and_continue),
+ wds.rename(image="jpg;png;jpeg;webp", text="text;txt;caption", handler=wds.warn_and_continue),
+ wds.map(filter_keys({"image", "text"})),
+ wds.map(transform),
+ wds.to_tuple("image", "text"),
+ ]
+
+ # Create train dataset and loader
+ pipeline = [
+ wds.ResampledShards(train_shards_path_or_url),
+ tarfile_to_samples_nothrow,
+ wds.shuffle(shuffle_buffer_size),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
+
+ # each worker is iterating over this
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+ # add meta-data to dataloader instance for convenience
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ @property
+ def train_dataset(self):
+ return self._train_dataset
+
+ @property
+ def train_dataloader(self):
+ return self._train_dataloader
+
+
+def log_validation(vae, unet, args, accelerator, weight_dtype, step):
+ logger.info("Running validation... ")
+
+ unet = accelerator.unwrap_model(unet)
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_teacher_model,
+ vae=vae,
+ scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ safety_checker=None,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype)
+ pipeline.load_lora_weights(lora_state_dict)
+ pipeline.fuse_lora()
+
+ pipeline = pipeline.to(accelerator.device, dtype=weight_dtype)
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ validation_prompts = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+ ]
+
+ image_logs = []
+
+ for _, prompt in enumerate(validation_prompts):
+ images = []
+ with torch.autocast("cuda", dtype=weight_dtype):
+ images = pipeline(
+ prompt=prompt,
+ num_inference_steps=4,
+ num_images_per_prompt=4,
+ generator=generator,
+ guidance_scale=1.0,
+ ).images
+ image_logs.append({"validation_prompt": prompt, "images": images})
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ formatted_images = []
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({"validation": formatted_images})
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+# From LatentConsistencyModel.get_guidance_scale_embedding
+def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ guidance scale values at which to generate embedding vectors
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
+ c_skip = sigma_data**2 / ((timestep / 0.1) ** 2 + sigma_data**2)
+ c_out = (timestep / 0.1) / ((timestep / 0.1) ** 2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+
+# Compare LCMScheduler.step, Step 4
+def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ if prediction_type == "epsilon":
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ pred_x_0 = (sample - sigmas * model_output) / alphas
+ elif prediction_type == "v_prediction":
+ pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
+ else:
+ raise ValueError(f"Prediction type {prediction_type} currently not supported.")
+
+ return pred_x_0
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+class DDIMSolver:
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+ # DDIM sampling parameters
+ step_ratio = timesteps // ddim_timesteps
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
+ self.ddim_alpha_cumprods_prev = np.asarray(
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+ )
+ # convert to torch tensors
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+ def to(self, device):
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+ return self
+
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev
+
+
+@torch.no_grad()
+def update_ema(target_params, source_params, rate=0.99):
+ """
+ Update target parameters to be closer to those of source parameters using
+ an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ # ----------Model Checkpoint Loading Arguments----------
+ parser.add_argument(
+ "--pretrained_teacher_model",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--teacher_revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
+ )
+ # ----------Training Arguments----------
+ # ----General Training Arguments----
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lcm-xl-distilled",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ # ----Logging----
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ # ----Checkpointing----
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ # ----Image Processing----
+ parser.add_argument(
+ "--train_shards_path_or_url",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ # ----Dataloader----
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ # ----Batch Size and Training Steps----
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ # ----Learning Rate----
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ # ----Optimizer (Adam)----
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ # ----Diffusion Training Arguments----
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
+ parser.add_argument(
+ "--w_min",
+ type=float,
+ default=5.0,
+ required=False,
+ help=(
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--w_max",
+ type=float,
+ default=15.0,
+ required=False,
+ help=(
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--num_ddim_timesteps",
+ type=int,
+ default=50,
+ help="The number of timesteps to use for DDIM sampling.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber"],
+ help="The type of loss to use for the LCD loss.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.001,
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
+ )
+ parser.add_argument(
+ "--lora_rank",
+ type=int,
+ default=64,
+ help="The rank of the LoRA projection matrix.",
+ )
+ # ----Mixed Precision----
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cast_teacher_unet",
+ action="store_true",
+ help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.",
+ )
+ # ----Training Optimizations----
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ # ----Distributed Training----
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ # ----------Validation Arguments----------
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ # ----------Huggingface Hub Arguments-----------
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ # ----------Accelerate Arguments----------
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0]
+
+ return prompt_embeds
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+        split_batches=True,  # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes, assuming batches are multiplied by the number of processes.
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # 1. Create the noise scheduler and the desired noise schedule.
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
+ )
+
+ # The scheduler calculates the alpha and sigma schedule for us
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
+    sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
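+    # alpha_schedule[t] = sqrt(alphas_cumprod[t]) and sigma_schedule[t] = sqrt(1 - alphas_cumprod[t]); these are the
+    # coefficients `predicted_origin` uses below to recover x_0 from the model output.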
+ solver = DDIMSolver(
+ noise_scheduler.alphas_cumprod.numpy(),
+ timesteps=noise_scheduler.config.num_train_timesteps,
+ ddim_timesteps=args.num_ddim_timesteps,
+ )
+
+    # 2. Load tokenizer from SD-1.5 checkpoint.
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
+ )
+
+ # 3. Load text encoders from SD-1.5 checkpoint.
+ # import correct text encoder classes
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
+ )
+
+    # 4. Load VAE from SD-1.5 checkpoint (or more stable VAE)
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_teacher_model,
+ subfolder="vae",
+ revision=args.teacher_revision,
+ )
+
+    # 5. Load teacher U-Net from SD-1.5 checkpoint
+ teacher_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+
+ # 6. Freeze teacher vae, text_encoder, and teacher_unet
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ teacher_unet.requires_grad_(False)
+
+ # 7. Create online (`unet`) student U-Nets.
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+ unet.train()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
+
+ # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
+ lora_config = LoraConfig(
+ r=args.lora_rank,
+ target_modules=[
+ "to_q",
+ "to_k",
+ "to_v",
+ "to_out.0",
+ "proj_in",
+ "proj_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "conv1",
+ "conv2",
+ "conv_shortcut",
+ "downsamplers.0.conv",
+ "upsamplers.0.conv",
+ "time_emb_proj",
+ ],
+ )
+ unet = get_peft_model(unet, lora_config)
+
+ # 9. Handle mixed precision and device placement
+    # For mixed precision training we cast all non-trainable weights to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device)
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # Move teacher_unet to device, optionally cast to weight_dtype
+ teacher_unet.to(accelerator.device)
+ if args.cast_teacher_unet:
+ teacher_unet.to(dtype=weight_dtype)
+
+ # Also move the alpha and sigma noise schedules to accelerator.device.
+ alpha_schedule = alpha_schedule.to(accelerator.device)
+ sigma_schedule = sigma_schedule.to(accelerator.device)
+ solver = solver.to(accelerator.device)
+
+ # 10. Handle saving and loading of checkpoints
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ unet_ = accelerator.unwrap_model(unet)
+ lora_state_dict = get_peft_model_state_dict(unet_, adapter_name="default")
+ StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict)
+ # save weights in peft format to be able to load them back
+ unet_.save_pretrained(output_dir)
+
+ for _, model in enumerate(models):
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ # load the LoRA into the model
+ unet_ = accelerator.unwrap_model(unet)
+ unet_.load_adapter(input_dir, "default", is_trainable=True)
+
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ models.pop()
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # 11. Enable optimizations
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+                logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ teacher_unet.enable_xformers_memory_efficient_attention()
+ # target_unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # 12. Optimizer creation
+ optimizer = optimizer_class(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+    # Compute the text embeddings used to condition the U-Net.
+ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
+ prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
+ return {"prompt_embeds": prompt_embeds}
+
+ dataset = Text2ImageDataset(
+ train_shards_path_or_url=args.train_shards_path_or_url,
+ num_train_examples=args.max_train_samples,
+ per_gpu_batch_size=args.train_batch_size,
+ global_batch_size=args.train_batch_size * accelerator.num_processes,
+ num_workers=args.dataloader_num_workers,
+ resolution=args.resolution,
+ shuffle_buffer_size=1000,
+ pin_memory=True,
+ persistent_workers=True,
+ )
+ train_dataloader = dataset.train_dataloader
+
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=0,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
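+    # Pre-compute the embedding of the empty prompt once; it is reused at every step as the teacher's unconditional
+    # CFG input.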
+ uncond_input_ids = tokenizer(
+ [""] * args.train_batch_size, return_tensors="pt", padding="max_length", max_length=77
+ ).input_ids.to(accelerator.device)
+ uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num batches each epoch = {train_dataloader.num_batches}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ image, text, _, _ = batch
+
+ image = image.to(accelerator.device, non_blocking=True)
+ encoded_text = compute_embeddings_fn(text)
+
+ pixel_values = image.to(dtype=weight_dtype)
+ if vae.dtype != weight_dtype:
+ vae.to(dtype=weight_dtype)
+
+ # encode pixel values with batch size of at most 32
+ latents = []
+ for i in range(0, pixel_values.shape[0], 32):
+ latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+
+ latents = latents * vae.config.scaling_factor
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
+ start_timesteps = solver.ddim_timesteps[index]
+ timesteps = start_timesteps - topk
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
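+                # With the default 1000 training timesteps and 50 DDIM timesteps, topk == 20, so start_timesteps lie
+                # on {19, 39, ..., 999} and timesteps is the previous timestep on the DDIM grid (clamped at 0).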
+
+ # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
+ c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
+
+ # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
+
+ # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+ w = w.reshape(bsz, 1, 1, 1)
+ w = w.to(device=latents.device, dtype=latents.dtype)
+
+ # 20.4.8. Prepare prompt embeds and unet_added_conditions
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+
+ # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
+ noise_pred = unet(
+ noisy_model_input,
+ start_timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+
+ pred_x_0 = predicted_origin(
+ noise_pred,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
+
+ # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
+ # noisy_latents with both the conditioning embedding c and unconditional embedding 0
+ # Get teacher model prediction on noisy_latents and conditional embedding
+ with torch.no_grad():
+ with torch.autocast("cuda"):
+ cond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=prompt_embeds.to(weight_dtype),
+ ).sample
+ cond_pred_x0 = predicted_origin(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # Get teacher model prediction on noisy_latents and unconditional embedding
+ uncond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
+ ).sample
+ uncond_pred_x0 = predicted_origin(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
+ pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index)
+
+ # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
+ with torch.no_grad():
+ with torch.autocast("cuda", dtype=weight_dtype):
+ target_noise_pred = unet(
+ x_prev.float(),
+ timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ ).sample
+ pred_x_0 = predicted_origin(
+ target_noise_pred,
+ timesteps,
+ x_prev,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ target = c_skip * x_prev + c_out * pred_x_0
+
+ # 20.4.13. Calculate loss
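+                # The "huber" option below is the smooth pseudo-Huber penalty sqrt(d**2 + c**2) - c, which behaves
+                # like d**2 / (2 * c) for small residuals and like |d| for large ones.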
+ if args.loss_type == "l2":
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ elif args.loss_type == "huber":
+ loss = torch.mean(
+ torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
+ )
+
+ # 20.4.14. Backpropagate on the online student model (`unet`)
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ log_validation(vae, unet, args, accelerator, weight_dtype, global_step)
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+    # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet.save_pretrained(args.output_dir)
+ lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
+ StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
new file mode 100644
index 000000000000..25faedf714b9
--- /dev/null
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -0,0 +1,1377 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import functools
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+from typing import List, Union
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision.transforms.functional as TF
+import transformers
+import webdataset as wds
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from braceexpand import braceexpand
+from huggingface_hub import create_repo
+from packaging import version
+from peft import LoraConfig, get_peft_model, get_peft_model_state_dict
+from torch.utils.data import default_collate
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+from webdataset.tariterators import (
+ base_plus_ext,
+ tar_file_expander,
+ url_opener,
+ valid_sample,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ LCMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+MAX_SEQ_LENGTH = 77
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.18.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"):
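+    # Rename peft LoRA keys to the Kohya-ss convention (lora_A -> lora_down, lora_B -> lora_up, dots -> underscores)
+    # so that the exported weights can be loaded back with `pipeline.load_lora_weights`.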
+ kohya_ss_state_dict = {}
+ for peft_key, weight in get_peft_model_state_dict(module, adapter_name=adapter_name).items():
+ kohya_key = peft_key.replace("base_model.model", prefix)
+ kohya_key = kohya_key.replace("lora_A", "lora_down")
+ kohya_key = kohya_key.replace("lora_B", "lora_up")
+ kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
+ kohya_ss_state_dict[kohya_key] = weight.to(dtype)
+
+ # Set alpha parameter
+ if "lora_down" in kohya_key:
+ alpha_key = f'{kohya_key.split(".")[0]}.alpha'
+ kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype)
+
+ return kohya_ss_state_dict
+
+
+def filter_keys(key_set):
+ def _f(dictionary):
+ return {k: v for k, v in dictionary.items() if k in key_set}
+
+ return _f
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+
+    :param keys: function that splits the key into key and extension (default: base_plus_ext)
+    :param lcase: convert suffixes to lower case (default: True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+        # FIXME: the upstream webdataset version throws if the suffix is already in current_sample, but this can
+        # happen in the current LAION-400M dataset when a tar ends with the same prefix the next one begins with.
+        # It is rare, but possible, since prefixes are not unique across tar files in that dataset.
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+class WebdatasetFilter:
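+    # Keeps only samples whose LAION-style json metadata reports original_width/original_height >= min_size and
+    # pwatermark <= max_pwatermark; samples without json metadata (or with malformed json) are dropped.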
+ def __init__(self, min_size=1024, max_pwatermark=0.5):
+ self.min_size = min_size
+ self.max_pwatermark = max_pwatermark
+
+ def __call__(self, x):
+ try:
+ if "json" in x:
+ x_json = json.loads(x["json"])
+ filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
+ "original_height", 0
+ ) >= self.min_size
+ filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+ return filter_size and filter_watermark
+ else:
+ return False
+ except Exception:
+ return False
+
+
+class Text2ImageDataset:
+ def __init__(
+ self,
+ train_shards_path_or_url: Union[str, List[str]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers: int,
+ resolution: int = 1024,
+ shuffle_buffer_size: int = 1000,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ use_fix_crop_and_size: bool = False,
+ ):
+ if not isinstance(train_shards_path_or_url, str):
+ train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+ # flatten list using itertools
+ train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
+
+ def get_orig_size(json):
+ if use_fix_crop_and_size:
+ return (resolution, resolution)
+ else:
+ return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
+
+ def transform(example):
+ # resize image
+ image = example["image"]
+ image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+
+ # get crop coordinates and crop image
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = TF.crop(image, c_top, c_left, resolution, resolution)
+ image = TF.to_tensor(image)
+ image = TF.normalize(image, [0.5], [0.5])
+
+ example["image"] = image
+ example["crop_coords"] = (c_top, c_left) if not use_fix_crop_and_size else (0, 0)
+ return example
+
+ processing_pipeline = [
+ wds.decode("pil", handler=wds.ignore_and_continue),
+ wds.rename(
+ image="jpg;png;jpeg;webp", text="text;txt;caption", orig_size="json", handler=wds.warn_and_continue
+ ),
+ wds.map(filter_keys({"image", "text", "orig_size"})),
+ wds.map_dict(orig_size=get_orig_size),
+ wds.map(transform),
+ wds.to_tuple("image", "text", "orig_size", "crop_coords"),
+ ]
+
+ # Create train dataset and loader
+ pipeline = [
+ wds.ResampledShards(train_shards_path_or_url),
+ tarfile_to_samples_nothrow,
+ wds.select(WebdatasetFilter(min_size=960)),
+ wds.shuffle(shuffle_buffer_size),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
+
+ # each worker is iterating over this
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+ # add meta-data to dataloader instance for convenience
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ @property
+ def train_dataset(self):
+ return self._train_dataset
+
+ @property
+ def train_dataloader(self):
+ return self._train_dataloader
+
+
+def log_validation(vae, unet, args, accelerator, weight_dtype, step):
+ logger.info("Running validation... ")
+
+ unet = accelerator.unwrap_model(unet)
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_teacher_model,
+ vae=vae,
+ scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype)
+ pipeline.load_lora_weights(lora_state_dict)
+ pipeline.fuse_lora()
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ validation_prompts = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+ ]
+
+ image_logs = []
+
+ for _, prompt in enumerate(validation_prompts):
+ images = []
+ with torch.autocast("cuda", dtype=weight_dtype):
+ images = pipeline(
+ prompt=prompt,
+ num_inference_steps=4,
+ num_images_per_prompt=4,
+ generator=generator,
+ guidance_scale=0.0,
+ ).images
+ image_logs.append({"validation_prompt": prompt, "images": images})
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ formatted_images = []
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({"validation": formatted_images})
+ else:
+            logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
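+    # At timestep 0 these scalings give c_skip == 1 and c_out == 0, enforcing the consistency boundary condition
+    # f(x, 0) == x.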
+    scaled_timestep = timestep_scaling * timestep
+    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+
+# Compare LCMScheduler.step, Step 4
+def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
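+    # Recover the predicted clean sample x_0 from the model output: for epsilon-prediction
+    # x_0 = (x_t - sigma_t * eps) / alpha_t, and for v-prediction x_0 = alpha_t * x_t - sigma_t * v, where
+    # alpha_t = sqrt(alphas_cumprod[t]) and sigma_t = sqrt(1 - alphas_cumprod[t]) are passed in as `alphas`/`sigmas`.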
+ if prediction_type == "epsilon":
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ pred_x_0 = (sample - sigmas * model_output) / alphas
+ elif prediction_type == "v_prediction":
+ pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
+ else:
+ raise ValueError(f"Prediction type {prediction_type} currently not supported.")
+
+ return pred_x_0
+
+
+def extract_into_tensor(a, t, x_shape):
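+    # Gather the per-timestep entries a[t] and reshape them to (batch, 1, ..., 1) so they broadcast over x_shape.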
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+class DDIMSolver:
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+ # DDIM sampling parameters
+ step_ratio = timesteps // ddim_timesteps
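+        # e.g. 1000 training timesteps with 50 DDIM timesteps gives step_ratio == 20 and ddim_timesteps [19, 39, ..., 999]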
+
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
+ self.ddim_alpha_cumprods_prev = np.asarray(
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+ )
+ # convert to torch tensors
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+ def to(self, device):
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+ return self
+
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
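+        # Deterministic DDIM update (eta = 0): x_prev = sqrt(alpha_bar_prev) * pred_x0 + sqrt(1 - alpha_bar_prev) * pred_noise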
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ # ----------Model Checkpoint Loading Arguments----------
+ parser.add_argument(
+ "--pretrained_teacher_model",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--teacher_revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
+ )
+ # ----------Training Arguments----------
+ # ----General Training Arguments----
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lcm-xl-distilled",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ # ----Logging----
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ # ----Checkpointing----
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ # ----Image Processing----
+ parser.add_argument(
+ "--train_shards_path_or_url",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--use_fix_crop_and_size",
+ action="store_true",
+ help="Whether or not to use the fixed crop and size for the teacher model.",
+ default=False,
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ # ----Dataloader----
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ # ----Batch Size and Training Steps----
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ # ----Learning Rate----
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ # ----Optimizer (Adam)----
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ # ----Diffusion Training Arguments----
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
+ parser.add_argument(
+ "--w_min",
+ type=float,
+ default=3.0,
+ required=False,
+ help=(
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--w_max",
+ type=float,
+ default=15.0,
+ required=False,
+ help=(
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--num_ddim_timesteps",
+ type=int,
+ default=50,
+ help="The number of timesteps to use for DDIM sampling.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber"],
+ help="The type of loss to use for the LCD loss.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.001,
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
+ )
+ parser.add_argument(
+ "--lora_rank",
+ type=int,
+ default=64,
+ help="The rank of the LoRA projection matrix.",
+ )
+ # ----Mixed Precision----
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cast_teacher_unet",
+ action="store_true",
+ help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.",
+ )
+ # ----Training Optimizations----
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ # ----Distributed Training----
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ # ----------Validation Arguments----------
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ # ----------Huggingface Hub Arguments-----------
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ # ----------Accelerate Arguments----------
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
+ prompt_embeds_list = []
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+            # We are only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+        split_batches=True,  # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes, assuming batches are multiplied by the number of processes.
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # 1. Create the noise scheduler and the desired noise schedule.
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
+ )
+
+ # The scheduler calculates the alpha and sigma schedule for us
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
+ sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
+ solver = DDIMSolver(
+ noise_scheduler.alphas_cumprod.numpy(),
+ timesteps=noise_scheduler.config.num_train_timesteps,
+ ddim_timesteps=args.num_ddim_timesteps,
+ )
+
+ # 2. Load tokenizers from SD-XL checkpoint.
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False
+ )
+
+ # 3. Load text encoders from SD-XL checkpoint.
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_teacher_model, args.teacher_revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_teacher_model, args.teacher_revision, subfolder="text_encoder_2"
+ )
+
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision
+ )
+
+ # 4. Load VAE from SD-XL checkpoint (or more stable VAE)
+ vae_path = (
+ args.pretrained_teacher_model
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.teacher_revision,
+ )
+
+ # 5. Load teacher U-Net from SD-XL checkpoint
+ teacher_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+
+ # 6. Freeze teacher vae, text_encoders, and teacher_unet
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ teacher_unet.requires_grad_(False)
+
+ # 7. Create online (`unet`) student U-Nets.
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+ unet.train()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
+
+ # 8. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
+ lora_config = LoraConfig(
+ r=args.lora_rank,
+ target_modules=[
+ "to_q",
+ "to_k",
+ "to_v",
+ "to_out.0",
+ "proj_in",
+ "proj_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "conv1",
+ "conv2",
+ "conv_shortcut",
+ "downsamplers.0.conv",
+ "upsamplers.0.conv",
+ "time_emb_proj",
+ ],
+ )
+ unet = get_peft_model(unet, lora_config)
+
+ # 9. Handle mixed precision and device placement
+    # For mixed precision training we cast all non-trainable weights to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device)
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # Move teacher_unet to device, optionally cast to weight_dtype
+ teacher_unet.to(accelerator.device)
+ if args.cast_teacher_unet:
+ teacher_unet.to(dtype=weight_dtype)
+
+ # Also move the alpha and sigma noise schedules to accelerator.device.
+ alpha_schedule = alpha_schedule.to(accelerator.device)
+ sigma_schedule = sigma_schedule.to(accelerator.device)
+ solver = solver.to(accelerator.device)
+
+ # 10. Handle saving and loading of checkpoints
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ unet_ = accelerator.unwrap_model(unet)
+ lora_state_dict = get_peft_model_state_dict(unet_, adapter_name="default")
+ StableDiffusionXLPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict)
+ # save weights in peft format to be able to load them back
+ unet_.save_pretrained(output_dir)
+
+ for _, model in enumerate(models):
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ # load the LoRA into the model
+ unet_ = accelerator.unwrap_model(unet)
+ unet_.load_adapter(input_dir, "default", is_trainable=True)
+
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ models.pop()
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # 11. Enable optimizations
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+                logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ teacher_unet.enable_xformers_memory_efficient_attention()
+ # target_unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # 12. Optimizer creation
+ optimizer = optimizer_class(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # 13. Dataset creation and data processing
+ # Here, we compute not just the text embeddings but also the additional embeddings
+ # needed for the SD XL UNet to operate.
+ def compute_embeddings(
+ prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True
+ ):
+ target_size = (args.resolution, args.resolution)
+ original_sizes = list(map(list, zip(*original_sizes)))
+ crops_coords_top_left = list(map(list, zip(*crop_coords)))
+
+ original_sizes = torch.tensor(original_sizes, dtype=torch.long)
+ crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long)
+
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
+ )
+ add_text_embeds = pooled_prompt_embeds
+
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
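+ # Per sample, the added time IDs are the concatenation of original size, crop top-left coordinates, and
+ # target size (6 values total), which are SDXL's micro-conditioning inputs.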
+ add_time_ids = list(target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
+ add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ add_text_embeds = add_text_embeds.to(accelerator.device)
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
+
+ dataset = Text2ImageDataset(
+ train_shards_path_or_url=args.train_shards_path_or_url,
+ num_train_examples=args.max_train_samples,
+ per_gpu_batch_size=args.train_batch_size,
+ global_batch_size=args.train_batch_size * accelerator.num_processes,
+ num_workers=args.dataloader_num_workers,
+ resolution=args.resolution,
+ shuffle_buffer_size=1000,
+ pin_memory=True,
+ persistent_workers=True,
+ use_fix_crop_and_size=args.use_fix_crop_and_size,
+ )
+ train_dataloader = dataset.train_dataloader
+
+ # Let's first compute all the embeddings so that we can free up the text encoders
+ # from memory.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=0,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ )
+
+ # 14. LR Scheduler creation
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # 15. Prepare for training
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Create uncond embeds for classifier free guidance
+ uncond_prompt_embeds = torch.zeros(args.train_batch_size, 77, 2048).to(accelerator.device)
+ uncond_pooled_prompt_embeds = torch.zeros(args.train_batch_size, 1280).to(accelerator.device)
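+ # The 2048-dim prompt embeds are the concatenation of the two SDXL text encoders' hidden states (768 + 1280),
+ # and the 1280-dim pooled embeds come from the second text encoder.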
+
+ # 16. Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num batches each epoch = {train_dataloader.num_batches}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ image, text, orig_size, crop_coords = batch
+
+ image = image.to(accelerator.device, non_blocking=True)
+ encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
+
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = image.to(dtype=weight_dtype)
+ if vae.dtype != weight_dtype:
+ vae.to(dtype=weight_dtype)
+ else:
+ pixel_values = image
+
+ # encode pixel values with batch size of at most 8
+ latents = []
+ for i in range(0, pixel_values.shape[0], 8):
+ latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
+ start_timesteps = solver.ddim_timesteps[index]
+ timesteps = start_timesteps - topk
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
+
+ # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
+ c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
+
+ # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
+
+ # 20.4.6. Sample a random guidance scale w from U[w_min, w_max]
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+ w = w.reshape(bsz, 1, 1, 1)
+ w = w.to(device=latents.device, dtype=latents.dtype)
+
+ # 20.4.8. Prepare prompt embeds and unet_added_conditions
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+
+ # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
+ noise_pred = unet(
+ noisy_model_input,
+ start_timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+
+ pred_x_0 = predicted_origin(
+ noise_pred,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
+
+ # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
+ # noisy_latents with both the conditioning embedding c and unconditional embedding 0
+ # Get teacher model prediction on noisy_latents and conditional embedding
+ with torch.no_grad():
+ with torch.autocast("cuda"):
+ cond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=prompt_embeds.to(weight_dtype),
+ added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
+ ).sample
+ cond_pred_x0 = predicted_origin(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # Get teacher model prediction on noisy_latents and unconditional embedding
+ uncond_added_conditions = copy.deepcopy(encoded_text)
+ uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
+ uncond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
+ added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
+ ).sample
+ uncond_pred_x0 = predicted_origin(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
+ pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index)
+
+ # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
+ with torch.no_grad():
+ with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
+ target_noise_pred = unet(
+ x_prev.float(),
+ timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+ pred_x_0 = predicted_origin(
+ target_noise_pred,
+ timesteps,
+ x_prev,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ target = c_skip * x_prev + c_out * pred_x_0
+
+ # 20.4.13. Calculate loss
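+ # The "huber" option below is the pseudo-Huber penalty sqrt(d**2 + c**2) - c, which behaves like L2 for
+ # small residuals and like L1 for large ones.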
+ if args.loss_type == "l2":
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ elif args.loss_type == "huber":
+ loss = torch.mean(
+ torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
+ )
+
+ # 20.4.14. Backpropagate on the online student model (`unet`)
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ log_validation(vae, unet, args, accelerator, weight_dtype, global_step)
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet.save_pretrained(args.output_dir)
+ lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default")
+ StableDiffusionXLPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict)
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
new file mode 100644
index 000000000000..ec4bf432f03d
--- /dev/null
+++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -0,0 +1,1302 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import functools
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+from typing import List, Union
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision.transforms.functional as TF
+import transformers
+import webdataset as wds
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from braceexpand import braceexpand
+from huggingface_hub import create_repo
+from packaging import version
+from torch.utils.data import default_collate
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
+from webdataset.tariterators import (
+ base_plus_ext,
+ tar_file_expander,
+ url_opener,
+ valid_sample,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ LCMScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+MAX_SEQ_LENGTH = 77
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.18.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def filter_keys(key_set):
+ def _f(dictionary):
+ return {k: v for k, v in dictionary.items() if k in key_set}
+
+ return _f
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+
+ :param keys: function that splits the key into key and extension (base_plus_ext)
+ :param lcase: convert suffixes to lower case (Default value = True)
+ """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+ # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+ # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
+ # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+class WebdatasetFilter:
+ def __init__(self, min_size=1024, max_pwatermark=0.5):
+ self.min_size = min_size
+ self.max_pwatermark = max_pwatermark
+
+ def __call__(self, x):
+ try:
+ if "json" in x:
+ x_json = json.loads(x["json"])
+ filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
+ "original_height", 0
+ ) >= self.min_size
+ filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+ return filter_size and filter_watermark
+ else:
+ return False
+ except Exception:
+ return False
+
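+# Note: WebdatasetFilter is not wired into the training pipeline below. If LAION-style "json" metadata is present
+# in the shards, it could (optionally) be applied with a stage such as wds.select(WebdatasetFilter(min_size=512)).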
+
+class Text2ImageDataset:
+ def __init__(
+ self,
+ train_shards_path_or_url: Union[str, List[str]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers: int,
+ resolution: int = 512,
+ shuffle_buffer_size: int = 1000,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ ):
+ if not isinstance(train_shards_path_or_url, str):
+ train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+ # flatten list using itertools
+ train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
+
+ def transform(example):
+ # resize image
+ image = example["image"]
+ image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+
+ # get crop coordinates and crop image
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = TF.crop(image, c_top, c_left, resolution, resolution)
+ image = TF.to_tensor(image)
+ image = TF.normalize(image, [0.5], [0.5])
+
+ example["image"] = image
+ return example
+
+ processing_pipeline = [
+ wds.decode("pil", handler=wds.ignore_and_continue),
+ wds.rename(image="jpg;png;jpeg;webp", text="text;txt;caption", handler=wds.warn_and_continue),
+ wds.map(filter_keys({"image", "text"})),
+ wds.map(transform),
+ wds.to_tuple("image", "text"),
+ ]
+
+ # Create train dataset and loader
+ pipeline = [
+ wds.ResampledShards(train_shards_path_or_url),
+ tarfile_to_samples_nothrow,
+ wds.shuffle(shuffle_buffer_size),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
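+ # Illustrative numbers: with num_train_examples=1_000_000, global_batch_size=256 and num_workers=4, each worker
+ # yields ceil(1_000_000 / 1024) = 977 batches, i.e. 3908 batches and 1_000_448 samples per "epoch".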
+
+ # each worker is iterating over this
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+ # add meta-data to dataloader instance for convenience
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ @property
+ def train_dataset(self):
+ return self._train_dataset
+
+ @property
+ def train_dataloader(self):
+ return self._train_dataloader
+
+
+def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="target"):
+ logger.info("Running validation... ")
+
+ unet = accelerator.unwrap_model(unet)
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_teacher_model,
+ vae=vae,
+ unet=unet,
+ scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ validation_prompts = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+ ]
+
+ image_logs = []
+
+ for _, prompt in enumerate(validation_prompts):
+ images = []
+ with torch.autocast("cuda"):
+ images = pipeline(
+ prompt=prompt,
+ num_inference_steps=4,
+ num_images_per_prompt=4,
+ generator=generator,
+ ).images
+ image_logs.append({"validation_prompt": prompt, "images": images})
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ formatted_images = []
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({f"validation/{name}": formatted_images})
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+# From LatentConsistencyModel.get_guidance_scale_embedding
+def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ guidance scale values at which to generate embedding vectors
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
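+# At timestep 0 these scalings reduce to c_skip = 1 and c_out = 0, enforcing the consistency boundary condition
+# f(x, 0) = x.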
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
+ scaled_timestep = timestep_scaling * timestep
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+
+# Compare LCMScheduler.step, Step 4
+def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ if prediction_type == "epsilon":
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ pred_x_0 = (sample - sigmas * model_output) / alphas
+ elif prediction_type == "v_prediction":
+ pred_x_0 = alphas[timesteps] * sample - sigmas[timesteps] * model_output
+ else:
+ raise ValueError(f"Prediction type {prediction_type} currently not supported.")
+
+ return pred_x_0
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+class DDIMSolver:
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+ # DDIM sampling parameters
+ step_ratio = timesteps // ddim_timesteps
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
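+ # e.g. with timesteps=1000 and ddim_timesteps=50, step_ratio=20 and the DDIM grid is [19, 39, ..., 999].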
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
+ self.ddim_alpha_cumprods_prev = np.asarray(
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+ )
+ # convert to torch tensors
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+ def to(self, device):
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+ return self
+
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
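+ # Deterministic DDIM update (eta=0): x_prev = sqrt(alpha_cumprod_prev) * pred_x0 + sqrt(1 - alpha_cumprod_prev) * pred_noise.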
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev
+
+
+@torch.no_grad()
+def update_ema(target_params, source_params, rate=0.99):
+ """
+ Update target parameters to be closer to those of source parameters using
+ an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ # ----------Model Checkpoint Loading Arguments----------
+ parser.add_argument(
+ "--pretrained_teacher_model",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--teacher_revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
+ )
+ # ----------Training Arguments----------
+ # ----General Training Arguments----
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lcm-xl-distilled",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ # ----Logging----
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ # ----Checkpointing----
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ # ----Image Processing----
+ parser.add_argument(
+ "--train_shards_path_or_url",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ # ----Dataloader----
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ # ----Batch Size and Training Steps----
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ # ----Learning Rate----
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ # ----Optimizer (Adam)----
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ # ----Diffusion Training Arguments----
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
+ parser.add_argument(
+ "--w_min",
+ type=float,
+ default=5.0,
+ required=False,
+ help=(
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--w_max",
+ type=float,
+ default=15.0,
+ required=False,
+ help=(
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--num_ddim_timesteps",
+ type=int,
+ default=50,
+ help="The number of timesteps to use for DDIM sampling.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber"],
+ help="The type of loss to use for the LCD loss.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.001,
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
+ )
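+ # NOTE: `args.unet_time_cond_proj_dim` is referenced below when creating the student U-Net and the
+ # guidance-scale embedding, so the argument is defined here; the default of 256 is an assumed embedding dimension.
+ parser.add_argument(
+ "--unet_time_cond_proj_dim",
+ type=int,
+ default=256,
+ help=(
+ "The dimension of the guidance scale embedding in the student U-Net, used when the teacher U-Net"
+ " does not define `time_cond_proj_dim`."
+ ),
+ )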
+ # ----Exponential Moving Average (EMA)----
+ parser.add_argument(
+ "--ema_decay",
+ type=float,
+ default=0.95,
+ required=False,
+ help="The exponential moving average (EMA) rate or decay factor.",
+ )
+ # ----Mixed Precision----
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cast_teacher_unet",
+ action="store_true",
+ help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.",
+ )
+ # ----Training Optimizations----
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ # ----Distributed Training----
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ # ----------Validation Arguments----------
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ # ----------Huggingface Hub Arguments-----------
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ # ----------Accelerate Arguments----------
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0]
+
+ return prompt_embeds
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ split_batches=True,  # Important: set this to True when using webdataset so the LR scheduler sees the right number of steps. If set to False, the number of steps is divided by the number of processes, assuming batches are multiplied by the number of processes.
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # 1. Create the noise scheduler and the desired noise schedule.
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
+ )
+
+ # The scheduler calculates the alpha and sigma schedule for us
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
+ sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
+ solver = DDIMSolver(
+ noise_scheduler.alphas_cumprod.numpy(),
+ timesteps=noise_scheduler.config.num_train_timesteps,
+ ddim_timesteps=args.num_ddim_timesteps,
+ )
+
+ # 2. Load the tokenizer from the SD-1.5 teacher checkpoint.
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
+ )
+
+ # 3. Load the text encoder from the SD-1.5 teacher checkpoint.
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
+ )
+
+ # 4. Load the VAE from the SD-1.5 teacher checkpoint (or a more numerically stable VAE)
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_teacher_model,
+ subfolder="vae",
+ revision=args.teacher_revision,
+ )
+
+ # 5. Load the teacher U-Net from the SD-1.5 teacher checkpoint
+ teacher_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+
+ # 6. Freeze teacher vae, text_encoder, and teacher_unet
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ teacher_unet.requires_grad_(False)
+
+ # 7. Create the online (`unet`) student U-Net. This will be updated by the optimizer (i.e. via backpropagation).
+ # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
+ if teacher_unet.config.time_cond_proj_dim is None:
+ teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
+ unet = UNet2DConditionModel(**teacher_unet.config)
+ # load teacher_unet weights into unet
+ unet.load_state_dict(teacher_unet.state_dict(), strict=False)
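+ # strict=False because the student U-Net gains a guidance-embedding projection (from `time_cond_proj_dim`) that
+ # the teacher checkpoint does not contain.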
+ unet.train()
+
+ # 8. Create the target (`target_unet`) student U-Net. Its parameters will be updated via EMA updates (Polyak averaging).
+ # Initialize from unet
+ target_unet = UNet2DConditionModel(**teacher_unet.config)
+ target_unet.load_state_dict(unet.state_dict())
+ target_unet.train()
+ target_unet.requires_grad_(False)
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
+
+ # 9. Handle mixed precision and device placement
+ # For mixed precision training we cast all non-trainable weights to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device)
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # Move target_unet and teacher_unet to device, and optionally cast teacher_unet to weight_dtype
+ target_unet.to(accelerator.device)
+ teacher_unet.to(accelerator.device)
+ if args.cast_teacher_unet:
+ teacher_unet.to(dtype=weight_dtype)
+
+ # Also move the alpha and sigma noise schedules to accelerator.device.
+ alpha_schedule = alpha_schedule.to(accelerator.device)
+ sigma_schedule = sigma_schedule.to(accelerator.device)
+ solver = solver.to(accelerator.device)
+
+ # 10. Handle saving and loading of checkpoints
+ # `accelerate` 0.16.0 and above have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ target_unet.save_pretrained(os.path.join(output_dir, "unet_target"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target"))
+ target_unet.load_state_dict(load_model.state_dict())
+ target_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # 11. Enable optimizations
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ teacher_unet.enable_xformers_memory_efficient_attention()
+ target_unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # 12. Optimizer creation
+ optimizer = optimizer_class(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Here, we compute the text embeddings used to condition the U-Net. Unlike the SD-XL script, no additional
+ # pooled or time-ID embeddings are needed.
+ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True):
+ prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train)
+ return {"prompt_embeds": prompt_embeds}
+
+ dataset = Text2ImageDataset(
+ train_shards_path_or_url=args.train_shards_path_or_url,
+ num_train_examples=args.max_train_samples,
+ per_gpu_batch_size=args.train_batch_size,
+ global_batch_size=args.train_batch_size * accelerator.num_processes,
+ num_workers=args.dataloader_num_workers,
+ resolution=args.resolution,
+ shuffle_buffer_size=1000,
+ pin_memory=True,
+ persistent_workers=True,
+ )
+ train_dataloader = dataset.train_dataloader
+
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=0,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ uncond_input_ids = tokenizer(
+ [""] * args.train_batch_size, return_tensors="pt", padding="max_length", max_length=77
+ ).input_ids.to(accelerator.device)
+ uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num batches each epoch = {train_dataloader.num_batches}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ image, text, _, _ = batch
+
+ image = image.to(accelerator.device, non_blocking=True)
+ encoded_text = compute_embeddings_fn(text)
+
+ pixel_values = image.to(dtype=weight_dtype)
+ if vae.dtype != weight_dtype:
+ vae.to(dtype=weight_dtype)
+
+ # encode pixel values with batch size of at most 32
+ latents = []
+ for i in range(0, pixel_values.shape[0], 32):
+ latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+
+ latents = latents * vae.config.scaling_factor
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
+ start_timesteps = solver.ddim_timesteps[index]
+ timesteps = start_timesteps - topk
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
+
+ # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
+ c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
+
+ # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
+
+ # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+ w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
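+ # w_embedding has shape (bsz, unet_time_cond_proj_dim) and conditions the student U-Net via `timestep_cond` below.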
+ w = w.reshape(bsz, 1, 1, 1)
+ # Move to U-Net device and dtype
+ w = w.to(device=latents.device, dtype=latents.dtype)
+ w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
+
+ # 20.4.8. Prepare prompt embeds and unet_added_conditions
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+
+ # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
+ noise_pred = unet(
+ noisy_model_input,
+ start_timesteps,
+ timestep_cond=w_embedding,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+
+ pred_x_0 = predicted_origin(
+ noise_pred,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
+
+ # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
+ # noisy_latents with both the conditioning embedding c and unconditional embedding 0
+ # Get teacher model prediction on noisy_latents and conditional embedding
+ with torch.no_grad():
+ with torch.autocast("cuda"):
+ cond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=prompt_embeds.to(weight_dtype),
+ ).sample
+ cond_pred_x0 = predicted_origin(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # Get teacher model prediction on noisy_latents and unconditional embedding
+ uncond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
+ ).sample
+ uncond_pred_x0 = predicted_origin(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
+ pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index)
+
+ # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
+ with torch.no_grad():
+ with torch.autocast("cuda", dtype=weight_dtype):
+ target_noise_pred = target_unet(
+ x_prev.float(),
+ timesteps,
+ timestep_cond=w_embedding,
+ encoder_hidden_states=prompt_embeds.float(),
+ ).sample
+ pred_x_0 = predicted_origin(
+ target_noise_pred,
+ timesteps,
+ x_prev,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ target = c_skip * x_prev + c_out * pred_x_0
+
+ # 20.4.13. Calculate loss
+ if args.loss_type == "l2":
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ elif args.loss_type == "huber":
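+                    # Element-wise pseudo-Huber loss: sqrt(d**2 + c**2) - c, which is roughly quadratic for
+                    # residuals much smaller than `huber_c` and roughly |d| for large residuals.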
+ loss = torch.mean(
+ torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
+ )
+
+ # 20.4.14. Backpropagate on the online student model (`unet`)
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ # 20.4.15. Make EMA update to target student model parameters
+ update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, "target")
+ log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "online")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+    # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet.save_pretrained(os.path.join(args.output_dir, "unet"))
+
+ target_unet = accelerator.unwrap_model(target_unet)
+ target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
new file mode 100644
index 000000000000..7d2b1e103208
--- /dev/null
+++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -0,0 +1,1399 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import functools
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+from typing import List, Union
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import torchvision.transforms.functional as TF
+import transformers
+import webdataset as wds
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from braceexpand import braceexpand
+from huggingface_hub import create_repo
+from packaging import version
+from torch.utils.data import default_collate
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+from webdataset.tariterators import (
+ base_plus_ext,
+ tar_file_expander,
+ url_opener,
+ valid_sample,
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ LCMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
+
+
+MAX_SEQ_LENGTH = 77
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.24.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def filter_keys(key_set):
+ def _f(dictionary):
+ return {k: v for k, v in dictionary.items() if k in key_set}
+
+ return _f
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+ """Return function over iterator that groups key, value pairs into samples.
+
+    :param keys: function that splits the key into key and extension (base_plus_ext)
+    :param lcase: convert suffixes to lower case (Default value = True)
+    """
+ current_sample = None
+ for filesample in data:
+ assert isinstance(filesample, dict)
+ fname, value = filesample["fname"], filesample["data"]
+ prefix, suffix = keys(fname)
+ if prefix is None:
+ continue
+ if lcase:
+ suffix = suffix.lower()
+            # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+            # this happening in the current LAION400m dataset if a tar ends with the same prefix as the next
+            # one begins with; rare, but it can happen since prefixes aren't unique across tar files in that dataset
+ if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+ if valid_sample(current_sample):
+ yield current_sample
+ current_sample = {"__key__": prefix, "__url__": filesample["__url__"]}
+ if suffixes is None or suffix in suffixes:
+ current_sample[suffix] = value
+ if valid_sample(current_sample):
+ yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue):
+ # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+ streams = url_opener(src, handler=handler)
+ files = tar_file_expander(streams, handler=handler)
+ samples = group_by_keys_nothrow(files, handler=handler)
+ return samples
+
+
+class WebdatasetFilter:
+ def __init__(self, min_size=1024, max_pwatermark=0.5):
+ self.min_size = min_size
+ self.max_pwatermark = max_pwatermark
+
+ def __call__(self, x):
+ try:
+ if "json" in x:
+ x_json = json.loads(x["json"])
+ filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get(
+ "original_height", 0
+ ) >= self.min_size
+ filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark
+ return filter_size and filter_watermark
+ else:
+ return False
+ except Exception:
+ return False
+
+
+class Text2ImageDataset:
+ def __init__(
+ self,
+ train_shards_path_or_url: Union[str, List[str]],
+ num_train_examples: int,
+ per_gpu_batch_size: int,
+ global_batch_size: int,
+ num_workers: int,
+ resolution: int = 1024,
+ shuffle_buffer_size: int = 1000,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ use_fix_crop_and_size: bool = False,
+ ):
+ if not isinstance(train_shards_path_or_url, str):
+ train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
+ # flatten list using itertools
+ train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url))
+
+ def get_orig_size(json):
+ if use_fix_crop_and_size:
+ return (resolution, resolution)
+ else:
+ return (int(json.get("original_width", 0.0)), int(json.get("original_height", 0.0)))
+
+ def transform(example):
+ # resize image
+ image = example["image"]
+ image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+
+ # get crop coordinates and crop image
+ c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
+ image = TF.crop(image, c_top, c_left, resolution, resolution)
+ image = TF.to_tensor(image)
+ image = TF.normalize(image, [0.5], [0.5])
+
+ example["image"] = image
+ example["crop_coords"] = (c_top, c_left) if not use_fix_crop_and_size else (0, 0)
+ return example
+
+ processing_pipeline = [
+ wds.decode("pil", handler=wds.ignore_and_continue),
+ wds.rename(
+ image="jpg;png;jpeg;webp", text="text;txt;caption", orig_size="json", handler=wds.warn_and_continue
+ ),
+ wds.map(filter_keys({"image", "text", "orig_size"})),
+ wds.map_dict(orig_size=get_orig_size),
+ wds.map(transform),
+ wds.to_tuple("image", "text", "orig_size", "crop_coords"),
+ ]
+
+ # Create train dataset and loader
+ pipeline = [
+ wds.ResampledShards(train_shards_path_or_url),
+ tarfile_to_samples_nothrow,
+ wds.select(WebdatasetFilter(min_size=960)),
+ wds.shuffle(shuffle_buffer_size),
+ *processing_pipeline,
+ wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate),
+ ]
+
+ num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker
+ num_batches = num_worker_batches * num_workers
+ num_samples = num_batches * global_batch_size
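+        # `with_epoch(num_worker_batches)` below caps each dataloader worker at `num_worker_batches` batches,
+        # so one "epoch" across all workers and processes covers roughly `num_train_examples` samples.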
+
+ # each worker is iterating over this
+ self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches)
+ self._train_dataloader = wds.WebLoader(
+ self._train_dataset,
+ batch_size=None,
+ shuffle=False,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ persistent_workers=persistent_workers,
+ )
+ # add meta-data to dataloader instance for convenience
+ self._train_dataloader.num_batches = num_batches
+ self._train_dataloader.num_samples = num_samples
+
+ @property
+ def train_dataset(self):
+ return self._train_dataset
+
+ @property
+ def train_dataloader(self):
+ return self._train_dataloader
+
+
+def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="target"):
+ logger.info("Running validation... ")
+
+ unet = accelerator.unwrap_model(unet)
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_teacher_model,
+ vae=vae,
+ unet=unet,
+ scheduler=LCMScheduler.from_pretrained(args.pretrained_teacher_model, subfolder="scheduler"),
+ revision=args.revision,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ validation_prompts = [
+ "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography",
+ "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
+ ]
+
+ image_logs = []
+
+ for _, prompt in enumerate(validation_prompts):
+ images = []
+ with torch.autocast("cuda"):
+ images = pipeline(
+ prompt=prompt,
+ num_inference_steps=4,
+ num_images_per_prompt=4,
+ generator=generator,
+ ).images
+ image_logs.append({"validation_prompt": prompt, "images": images})
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ formatted_images = []
+ for image in images:
+ formatted_images.append(np.asarray(image))
+
+ formatted_images = np.stack(formatted_images)
+
+ tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({f"validation/{name}": formatted_images})
+ else:
+ logger.warn(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
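+    # These scalings enforce the consistency-model boundary condition: c_skip -> 1 and c_out -> 0 as the
+    # timestep approaches 0, so the parametrized model reduces to the identity at t = 0.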
+    scaled_timestep = timestep_scaling * timestep
+    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+
+# Compare LCMScheduler.step, Step 4
+def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+    # Recover the predicted clean sample x_0 from the model output for the forward process
+    # x_t = alpha_t * x_0 + sigma_t * noise.
+    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+    if prediction_type == "epsilon":
+        pred_x_0 = (sample - sigmas * model_output) / alphas
+    elif prediction_type == "v_prediction":
+        # Broadcast alphas/sigmas to the sample shape here as well; plain indexing would yield a 1-D tensor
+        # that does not broadcast against a (B, C, H, W) latent.
+        pred_x_0 = alphas * sample - sigmas * model_output
+    else:
+        raise ValueError(f"Prediction type {prediction_type} currently not supported.")
+
+    return pred_x_0
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+@torch.no_grad()
+def update_ema(target_params, source_params, rate=0.99):
+ """
+ Update target parameters to be closer to those of source parameters using
+ an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+# From LatentConsistencyModel.get_guidance_scale_embedding
+def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+        w (`torch.Tensor`):
+            guidance scale values at which to generate the embedding vectors
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+        `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+
+class DDIMSolver:
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+ # DDIM sampling parameters
+ step_ratio = timesteps // ddim_timesteps
+
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
+ self.ddim_alpha_cumprods_prev = np.asarray(
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+ )
+ # convert to torch tensors
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+ def to(self, device):
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+ return self
+
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
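+        # Deterministic DDIM update (eta = 0): x_prev = sqrt(alpha_bar_prev) * pred_x0 + sqrt(1 - alpha_bar_prev) * pred_noise,
+        # where `timestep_index` indexes the precomputed DDIM timestep grid.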
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ # ----------Model Checkpoint Loading Arguments----------
+ parser.add_argument(
+ "--pretrained_teacher_model",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--teacher_revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
+ )
+ # ----------Training Arguments----------
+ # ----General Training Arguments----
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lcm-xl-distilled",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ # ----Logging----
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ # ----Checkpointing----
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ # ----Image Processing----
+ parser.add_argument(
+ "--train_shards_path_or_url",
+ type=str,
+ default=None,
+        help=(
+            "The path or URL of the WebDataset training shards to stream from. This can be a brace-expandable"
+            " pattern (or a list of such patterns) pointing to local or remote tar shards."
+        ),
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--use_fix_crop_and_size",
+ action="store_true",
+ help="Whether or not to use the fixed crop and size for the teacher model.",
+ default=False,
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ # ----Dataloader----
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ # ----Batch Size and Training Steps----
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ # ----Learning Rate----
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ # ----Optimizer (Adam)----
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ # ----Diffusion Training Arguments----
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
+ parser.add_argument(
+ "--w_min",
+ type=float,
+ default=3.0,
+ required=False,
+ help=(
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--w_max",
+ type=float,
+ default=15.0,
+ required=False,
+ help=(
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--num_ddim_timesteps",
+ type=int,
+ default=50,
+ help="The number of timesteps to use for DDIM sampling.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber"],
+ help="The type of loss to use for the LCD loss.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.001,
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
+ )
+    parser.add_argument(
+        "--unet_time_cond_proj_dim",
+        type=int,
+        default=256,
+        help=(
+            "The dimension of the guidance scale embedding (`time_cond_proj_dim`) added to the student U-Net when"
+            " the teacher U-Net does not already define one."
+        ),
+    )
+ # ----Exponential Moving Average (EMA)----
+ parser.add_argument(
+ "--ema_decay",
+ type=float,
+ default=0.95,
+ required=False,
+ help="The exponential moving average (EMA) rate or decay factor.",
+ )
+ # ----Mixed Precision----
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system"
+            " or the flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cast_teacher_unet",
+ action="store_true",
+ help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.",
+ )
+ # ----Training Optimizations----
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ # ----Distributed Training----
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ # ----------Validation Arguments----------
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=200,
+ help="Run validation every X steps.",
+ )
+ # ----------Huggingface Hub Arguments-----------
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ # ----------Accelerate Arguments----------
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+        help=(
+            "The `project_name` argument passed to Accelerator.init_trackers. For more information see"
+            " https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+        ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True):
+ prompt_embeds_list = []
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+                # We are only interested in the pooled output of the final text encoder.
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+def main(args):
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+        split_batches=True,  # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes, assuming batches are multiplied by the number of processes.
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ token=args.hub_token,
+ private=True,
+ ).repo_id
+
+ # 1. Create the noise scheduler and the desired noise schedule.
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
+ )
+
+ # The scheduler calculates the alpha and sigma schedule for us
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
+ sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
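+    # These correspond to the forward process x_t = alpha_t * x_0 + sigma_t * noise, with
+    # alpha_t = sqrt(alphas_cumprod[t]) and sigma_t = sqrt(1 - alphas_cumprod[t]).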
+ solver = DDIMSolver(
+ noise_scheduler.alphas_cumprod.numpy(),
+ timesteps=noise_scheduler.config.num_train_timesteps,
+ ddim_timesteps=args.num_ddim_timesteps,
+ )
+
+ # 2. Load tokenizers from SD-XL checkpoint.
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_teacher_model, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False
+ )
+
+ # 3. Load text encoders from SD-XL checkpoint.
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_teacher_model, args.teacher_revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_teacher_model, args.teacher_revision, subfolder="text_encoder_2"
+ )
+
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision
+ )
+
+ # 4. Load VAE from SD-XL checkpoint (or more stable VAE)
+ vae_path = (
+ args.pretrained_teacher_model
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.teacher_revision,
+ )
+
+ # 5. Load teacher U-Net from SD-XL checkpoint
+ teacher_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
+ )
+
+ # 6. Freeze teacher vae, text_encoders, and teacher_unet
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ teacher_unet.requires_grad_(False)
+
+    # 7. Create the online (`unet`) student U-Net. This will be updated by the optimizer (e.g., via backpropagation).
+ # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
+ if teacher_unet.config.time_cond_proj_dim is None:
+ teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
+ unet = UNet2DConditionModel(**teacher_unet.config)
+ # load teacher_unet weights into unet
+ unet.load_state_dict(teacher_unet.state_dict(), strict=False)
+ unet.train()
+
+    # 8. Create the target (`target_unet`) student U-Net. This will be updated via EMA updates (Polyak averaging).
+ # Initialize from unet
+ target_unet = UNet2DConditionModel(**teacher_unet.config)
+ target_unet.load_state_dict(unet.state_dict())
+ target_unet.train()
+ target_unet.requires_grad_(False)
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if accelerator.unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(
+            f"UNet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
+
+ # 9. Handle mixed precision and device placement
+ # For mixed precision training we cast all non-trainable weigths to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device)
+ if args.pretrained_vae_model_name_or_path is not None:
+ vae.to(dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+ target_unet.to(accelerator.device)
+ # Move teacher_unet to device, optionally cast to weight_dtype
+ teacher_unet.to(accelerator.device)
+ if args.cast_teacher_unet:
+ teacher_unet.to(dtype=weight_dtype)
+
+ # Also move the alpha and sigma noise schedules to accelerator.device.
+ alpha_schedule = alpha_schedule.to(accelerator.device)
+ sigma_schedule = sigma_schedule.to(accelerator.device)
+ solver = solver.to(accelerator.device)
+
+ # 10. Handle saving and loading of checkpoints
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ target_unet.save_pretrained(os.path.join(output_dir, "unet_target"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target"))
+ target_unet.load_state_dict(load_model.state_dict())
+ target_unet.to(accelerator.device)
+ del load_model
+
+ for i in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # 11. Enable optimizations
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warn(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ teacher_unet.enable_xformers_memory_efficient_attention()
+ target_unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # 12. Optimizer creation
+ optimizer = optimizer_class(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # 13. Dataset creation and data processing
+ # Here, we compute not just the text embeddings but also the additional embeddings
+ # needed for the SD XL UNet to operate.
+ def compute_embeddings(
+ prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True
+ ):
+ target_size = (args.resolution, args.resolution)
+ original_sizes = list(map(list, zip(*original_sizes)))
+ crops_coords_top_left = list(map(list, zip(*crop_coords)))
+
+ original_sizes = torch.tensor(original_sizes, dtype=torch.long)
+ crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long)
+
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train
+ )
+ add_text_embeds = pooled_prompt_embeds
+
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ add_time_ids = list(target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.repeat(len(prompt_batch), 1)
+ add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1)
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ add_text_embeds = add_text_embeds.to(accelerator.device)
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
+
+ dataset = Text2ImageDataset(
+ train_shards_path_or_url=args.train_shards_path_or_url,
+ num_train_examples=args.max_train_samples,
+ per_gpu_batch_size=args.train_batch_size,
+ global_batch_size=args.train_batch_size * accelerator.num_processes,
+ num_workers=args.dataloader_num_workers,
+ resolution=args.resolution,
+ shuffle_buffer_size=1000,
+ pin_memory=True,
+ persistent_workers=True,
+ use_fix_crop_and_size=args.use_fix_crop_and_size,
+ )
+ train_dataloader = dataset.train_dataloader
+
+ # Let's first compute all the embeddings so that we can free up the text encoders
+ # from memory.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ proportion_empty_prompts=0,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ )
+
+ # 14. LR Scheduler creation
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps,
+ num_training_steps=args.max_train_steps,
+ )
+
+ # 15. Prepare for training
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Create uncond embeds for classifier free guidance
+ uncond_prompt_embeds = torch.zeros(args.train_batch_size, 77, 2048).to(accelerator.device)
+ uncond_pooled_prompt_embeds = torch.zeros(args.train_batch_size, 1280).to(accelerator.device)
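+    # The unconditional input for SDXL is an all-zeros prompt embedding of shape (batch, 77, 2048), i.e. the
+    # concatenated hidden states of both text encoders, plus an all-zeros pooled embedding of size 1280.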
+
+ # 16. Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num batches each epoch = {train_dataloader.num_batches}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ image, text, orig_size, crop_coords = batch
+
+ image = image.to(accelerator.device, non_blocking=True)
+ encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
+
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = image.to(dtype=weight_dtype)
+ if vae.dtype != weight_dtype:
+ vae.to(dtype=weight_dtype)
+ else:
+ pixel_values = image
+
+ # encode pixel values with batch size of at most 8
+ latents = []
+ for i in range(0, pixel_values.shape[0], 8):
+ latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image t_n ~ U[0, N - k - 1] without bias.
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
+ start_timesteps = solver.ddim_timesteps[index]
+ timesteps = start_timesteps - topk
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
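+                # `topk` is the skipping step k between adjacent DDIM grid points: the online student is evaluated
+                # at t_{n + k} (`start_timesteps`) and distilled towards the target network's output at t_n (`timesteps`).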
+
+ # 20.4.4. Get boundary scalings for start_timesteps and (end) timesteps.
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(start_timesteps)
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
+ c_skip, c_out = scalings_for_boundary_conditions(timesteps)
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
+
+ # 20.4.5. Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
+
+ # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+ w = w.reshape(bsz, 1, 1, 1)
+ w = w.to(device=latents.device, dtype=latents.dtype)
+
+ # 20.4.8. Prepare prompt embeds and unet_added_conditions
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+
+ # 20.4.9. Get online LCM prediction on z_{t_{n + k}}, w, c, t_{n + k}
+ noise_pred = unet(
+ noisy_model_input,
+ start_timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+
+ pred_x_0 = predicted_origin(
+ noise_pred,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
+
+ # 20.4.10. Use the ODE solver to predict the kth step in the augmented PF-ODE trajectory after
+ # noisy_latents with both the conditioning embedding c and unconditional embedding 0
+ # Get teacher model prediction on noisy_latents and conditional embedding
+ with torch.no_grad():
+ with torch.autocast("cuda"):
+ cond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=prompt_embeds.to(weight_dtype),
+ added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()},
+ ).sample
+ cond_pred_x0 = predicted_origin(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # Get teacher model prediction on noisy_latents and unconditional embedding
+ uncond_added_conditions = copy.deepcopy(encoded_text)
+ uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds
+ uncond_teacher_output = teacher_unet(
+ noisy_model_input.to(weight_dtype),
+ start_timesteps,
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
+ added_cond_kwargs={k: v.to(weight_dtype) for k, v in uncond_added_conditions.items()},
+ ).sample
+ uncond_pred_x0 = predicted_origin(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 20.4.11. Perform "CFG" to get x_prev estimate (using the LCM paper's CFG formulation)
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
+ pred_noise = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index)
+
+ # 20.4.12. Get target LCM prediction on x_prev, w, c, t_n
+ with torch.no_grad():
+ with torch.autocast("cuda", dtype=weight_dtype):
+ target_noise_pred = target_unet(
+ x_prev.float(),
+ timesteps,
+ timestep_cond=None,
+ encoder_hidden_states=prompt_embeds.float(),
+ added_cond_kwargs=encoded_text,
+ ).sample
+ pred_x_0 = predicted_origin(
+ target_noise_pred,
+ timesteps,
+ x_prev,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ target = c_skip * x_prev + c_out * pred_x_0
+
+ # 20.4.13. Calculate loss
+ if args.loss_type == "l2":
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+ elif args.loss_type == "huber":
+ loss = torch.mean(
+ torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
+ )
+
+ # 20.4.14. Backpropagate on the online student model (`unet`)
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ # 20.4.15. Make EMA update to target student model parameters
+ update_ema(target_unet.parameters(), unet.parameters(), args.ema_decay)
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, "target")
+ log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "online")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+    # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet.save_pretrained(os.path.join(args.output_dir, "unet"))
+
+ target_unet = accelerator.unwrap_model(target_unet)
+ target_unet.save_pretrained(os.path.join(args.output_dir, "unet_target"))
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py
index 76fcb547b6f9..63b6767a6f8f 100644
--- a/examples/controlnet/train_controlnet.py
+++ b/examples/controlnet/train_controlnet.py
@@ -56,7 +56,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index 47483883824e..b658f689358d 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -59,7 +59,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py
index 5f745966c9d4..b4fa96dae8ff 100644
--- a/examples/controlnet/train_controlnet_sdxl.py
+++ b/examples/controlnet/train_controlnet_sdxl.py
@@ -58,7 +58,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index 894fb39deeb8..d69ce2f28802 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -62,7 +62,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 88d05be4561d..8c103e6204f8 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -61,7 +61,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py
index d2c0f8697baa..5c37484e86bd 100644
--- a/examples/dreambooth/train_dreambooth_flax.py
+++ b/examples/dreambooth/train_dreambooth_flax.py
@@ -35,7 +35,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 953d8e637d1e..a82d880ff5b1 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -68,7 +68,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index e8dd6777f32c..002e01b28405 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -58,7 +58,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py
index 58baca312ce2..b9b1c9cc5b3b 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py
@@ -52,7 +52,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index 288404b4728c..6b503cb29275 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -55,7 +55,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
index 9ad01357c1f5..bc0a64b42e4b 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
@@ -52,7 +52,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
index 472010320d73..2f968aa8b8b3 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
@@ -46,7 +46,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
index a007d8c74b0c..317e4178c04c 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
@@ -46,7 +46,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
index 799f9fbcb3ac..0e6d06074012 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
@@ -51,7 +51,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
index 783678cd346b..d1c9113bbd9d 100644
--- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py
+++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
@@ -58,7 +58,7 @@
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 89e154ef8825..628a0c9d7d96 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -53,7 +53,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py
index 64b71b4f83ae..9ebe34555310 100644
--- a/examples/text_to_image/train_text_to_image_flax.py
+++ b/examples/text_to_image/train_text_to_image_flax.py
@@ -33,7 +33,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = logging.getLogger(__name__)
@@ -275,6 +275,7 @@ def main():
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir
)
else:
data_files = {}
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index de4076a2ceaf..78b443d149e8 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -49,7 +49,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index f0d83d55e9bf..1a6ef0c856db 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -58,7 +58,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
@@ -768,6 +768,7 @@ def load_model_hook(models, input_dir):
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir
)
else:
data_files = {}
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index a385795b1a4f..041464e701cc 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -57,7 +57,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 55c907663249..8ce998aab1fb 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -79,7 +79,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index 938454eecb6e..5de1a8d7c325 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -56,7 +56,7 @@
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = logging.getLogger(__name__)
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 99c858778259..6e552c9b3dde 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -29,7 +29,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
index 48e5b96087de..33de3d3bf777 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
@@ -50,7 +50,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
index b1e5abaaa278..62450679f201 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
+++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
@@ -51,7 +51,7 @@
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.23.0.dev0")
+check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/scripts/convert_consistency_decoder.py b/scripts/convert_consistency_decoder.py
new file mode 100644
index 000000000000..8e6da07d8c6c
--- /dev/null
+++ b/scripts/convert_consistency_decoder.py
@@ -0,0 +1,1128 @@
+import hashlib
+import math
+import os
+import urllib
+import warnings
+from argparse import ArgumentParser
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from safetensors.torch import load_file as stl
+from tqdm import tqdm
+
+from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline, StableDiffusionPipeline, UNet2DModel
+from diffusers.models.embeddings import TimestepEmbedding
+from diffusers.models.unet_2d_blocks import ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, UNetMidBlock2D
+from diffusers.models.vae import Encoder
+
+
+args = ArgumentParser()
+args.add_argument("--save_pretrained", required=False, default=None, type=str)
+args.add_argument("--test_image", required=True, type=str)
+args = args.parse_args()
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895
+ res = arr[timesteps].float()
+ dims_to_append = len(broadcast_shape) - len(res.shape)
+ return res[(...,) + (None,) * dims_to_append]
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L45
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return torch.tensor(betas)
+
+
+def _download(url: str, root: str):
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ expected_sha256 = url.split("/")[-2]
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(
+ total=int(source.info().get("Content-Length")),
+ ncols=80,
+ unit="iB",
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
+
+
+class ConsistencyDecoder:
+ def __init__(self, device="cuda:0", download_root=os.path.expanduser("~/.cache/clip")):
+ self.n_distilled_steps = 64
+ download_target = _download(
+ "https://openaipublic.azureedge.net/diff-vae/c9cebd3132dd9c42936d803e33424145a748843c8f716c0814838bdc8a2fe7cb/decoder.pt",
+ download_root,
+ )
+ self.ckpt = torch.jit.load(download_target).to(device)
+ self.device = device
+ sigma_data = 0.5
+ betas = betas_for_alpha_bar(1024, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2).to(device)
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+ sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
+ sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)
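+ # Consistency-model style preconditioning: c_skip mixes in the (scaled) noisy input,
+ # c_out scales the network output, and c_in normalizes the input, with sigma_data = 0.5.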
+ self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2)
+ self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
+ self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5
+
+ @staticmethod
+ def round_timesteps(timesteps, total_timesteps, n_distilled_steps, truncate_start=True):
+ with torch.no_grad():
+ space = torch.div(total_timesteps, n_distilled_steps, rounding_mode="floor")
+ rounded_timesteps = (torch.div(timesteps, space, rounding_mode="floor") + 1) * space
+ if truncate_start:
+ rounded_timesteps[rounded_timesteps == total_timesteps] -= space
+ else:
+ rounded_timesteps[rounded_timesteps == total_timesteps] -= space
+ rounded_timesteps[rounded_timesteps == 0] += space
+ return rounded_timesteps
+
+ @staticmethod
+ def ldm_transform_latent(z, extra_scale_factor=1):
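+ # Rescale the Stable Diffusion latent by 0.18215 and normalize each of its four channels
+ # with fixed per-channel means and stds before it is fed to the decoder.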
+ channel_means = [0.38862467, 0.02253063, 0.07381133, -0.0171294]
+ channel_stds = [0.9654121, 1.0440036, 0.76147926, 0.77022034]
+
+ if len(z.shape) != 4:
+ raise ValueError()
+
+ z = z * 0.18215
+ channels = [z[:, i] for i in range(z.shape[1])]
+
+ channels = [extra_scale_factor * (c - channel_means[i]) / channel_stds[i] for i, c in enumerate(channels)]
+ return torch.stack(channels, dim=1)
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ features: torch.Tensor,
+ schedule=[1.0, 0.5],
+ generator=None,
+ ):
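+ # Multi-step consistency sampling: for each entry in `schedule` (two steps by default,
+ # t = 1023 and t = 511 before rounding onto the distilled grid), re-noise the current
+ # estimate to that timestep and map it back to a clean sample in one network call.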
+ features = self.ldm_transform_latent(features)
+ ts = self.round_timesteps(
+ torch.arange(0, 1024),
+ 1024,
+ self.n_distilled_steps,
+ truncate_start=False,
+ )
+ shape = (
+ features.size(0),
+ 3,
+ 8 * features.size(2),
+ 8 * features.size(3),
+ )
+ x_start = torch.zeros(shape, device=features.device, dtype=features.dtype)
+ schedule_timesteps = [int((1024 - 1) * s) for s in schedule]
+ for i in schedule_timesteps:
+ t = ts[i].item()
+ t_ = torch.tensor([t] * features.shape[0]).to(self.device)
+ # noise = torch.randn_like(x_start)
+ noise = torch.randn(x_start.shape, dtype=x_start.dtype, generator=generator).to(device=x_start.device)
+ x_start = (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t_, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t_, x_start.shape) * noise
+ )
+ c_in = _extract_into_tensor(self.c_in, t_, x_start.shape)
+
+ import torch.nn.functional as F
+
+ from diffusers import UNet2DModel
+
+ if isinstance(self.ckpt, UNet2DModel):
+ input = torch.concat([c_in * x_start, F.upsample_nearest(features, scale_factor=8)], dim=1)
+ model_output = self.ckpt(input, t_).sample
+ else:
+ model_output = self.ckpt(c_in * x_start, t_, features=features)
+
+ B, C = x_start.shape[:2]
+ model_output, _ = torch.split(model_output, C, dim=1)
+ pred_xstart = (
+ _extract_into_tensor(self.c_out, t_, x_start.shape) * model_output
+ + _extract_into_tensor(self.c_skip, t_, x_start.shape) * x_start
+ ).clamp(-1, 1)
+ x_start = pred_xstart
+ return x_start
+
+
+def save_image(image, name):
+ import numpy as np
+ from PIL import Image
+
+ image = image[0].cpu().numpy()
+ image = (image + 1.0) * 127.5
+ image = image.clip(0, 255).astype(np.uint8)
+ image = Image.fromarray(image.transpose(1, 2, 0))
+ image.save(name)
+
+
+def load_image(uri, size=None, center_crop=False):
+ import numpy as np
+ from PIL import Image
+
+ image = Image.open(uri)
+ if center_crop:
+ image = image.crop(
+ (
+ (image.width - min(image.width, image.height)) // 2,
+ (image.height - min(image.width, image.height)) // 2,
+ (image.width + min(image.width, image.height)) // 2,
+ (image.height + min(image.width, image.height)) // 2,
+ )
+ )
+ if size is not None:
+ image = image.resize(size)
+ image = torch.tensor(np.array(image).transpose(2, 0, 1)).unsqueeze(0).float()
+ image = image / 127.5 - 1.0
+ return image
+
+
+class TimestepEmbedding_(nn.Module):
+ def __init__(self, n_time=1024, n_emb=320, n_out=1280) -> None:
+ super().__init__()
+ self.emb = nn.Embedding(n_time, n_emb)
+ self.f_1 = nn.Linear(n_emb, n_out)
+ self.f_2 = nn.Linear(n_out, n_out)
+
+ def forward(self, x) -> torch.Tensor:
+ x = self.emb(x)
+ x = self.f_1(x)
+ x = F.silu(x)
+ return self.f_2(x)
+
+
+class ImageEmbedding(nn.Module):
+ def __init__(self, in_channels=7, out_channels=320) -> None:
+ super().__init__()
+ self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+ def forward(self, x) -> torch.Tensor:
+ return self.f(x)
+
+
+class ImageUnembedding(nn.Module):
+ def __init__(self, in_channels=320, out_channels=6) -> None:
+ super().__init__()
+ self.gn = nn.GroupNorm(32, in_channels)
+ self.f = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
+ def forward(self, x) -> torch.Tensor:
+ return self.f(F.silu(self.gn(x)))
+
+
+class ConvResblock(nn.Module):
+ def __init__(self, in_features=320, out_features=320) -> None:
+ super().__init__()
+ self.f_t = nn.Linear(1280, out_features * 2)
+
+ self.gn_1 = nn.GroupNorm(32, in_features)
+ self.f_1 = nn.Conv2d(in_features, out_features, kernel_size=3, padding=1)
+
+ self.gn_2 = nn.GroupNorm(32, out_features)
+ self.f_2 = nn.Conv2d(out_features, out_features, kernel_size=3, padding=1)
+
+ skip_conv = in_features != out_features
+ self.f_s = nn.Conv2d(in_features, out_features, kernel_size=1, padding=0) if skip_conv else nn.Identity()
+
+ def forward(self, x, t):
+ x_skip = x
+ t = self.f_t(F.silu(t))
+ t = t.chunk(2, dim=1)
+ t_1 = t[0].unsqueeze(dim=2).unsqueeze(dim=3) + 1
+ t_2 = t[1].unsqueeze(dim=2).unsqueeze(dim=3)
+
+ gn_1 = F.silu(self.gn_1(x))
+ f_1 = self.f_1(gn_1)
+
+ gn_2 = self.gn_2(f_1)
+
+ return self.f_s(x_skip) + self.f_2(F.silu(gn_2 * t_1 + t_2))
+
+
+# Also ConvResblock
+class Downsample(nn.Module):
+ def __init__(self, in_channels=320) -> None:
+ super().__init__()
+ self.f_t = nn.Linear(1280, in_channels * 2)
+
+ self.gn_1 = nn.GroupNorm(32, in_channels)
+ self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+ self.gn_2 = nn.GroupNorm(32, in_channels)
+
+ self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+
+ def forward(self, x, t) -> torch.Tensor:
+ x_skip = x
+
+ t = self.f_t(F.silu(t))
+ t_1, t_2 = t.chunk(2, dim=1)
+ t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
+ t_2 = t_2.unsqueeze(2).unsqueeze(3)
+
+ gn_1 = F.silu(self.gn_1(x))
+ avg_pool2d = F.avg_pool2d(gn_1, kernel_size=(2, 2), stride=None)
+
+ f_1 = self.f_1(avg_pool2d)
+ gn_2 = self.gn_2(f_1)
+
+ f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
+
+ return f_2 + F.avg_pool2d(x_skip, kernel_size=(2, 2), stride=None)
+
+
+# Also ConvResblock
+class Upsample(nn.Module):
+ def __init__(self, in_channels=1024) -> None:
+ super().__init__()
+ self.f_t = nn.Linear(1280, in_channels * 2)
+
+ self.gn_1 = nn.GroupNorm(32, in_channels)
+ self.f_1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+ self.gn_2 = nn.GroupNorm(32, in_channels)
+
+ self.f_2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
+
+ def forward(self, x, t) -> torch.Tensor:
+ x_skip = x
+
+ t = self.f_t(F.silu(t))
+ t_1, t_2 = t.chunk(2, dim=1)
+ t_1 = t_1.unsqueeze(2).unsqueeze(3) + 1
+ t_2 = t_2.unsqueeze(2).unsqueeze(3)
+
+ gn_1 = F.silu(self.gn_1(x))
+ upsample = F.upsample_nearest(gn_1, scale_factor=2)
+ f_1 = self.f_1(upsample)
+ gn_2 = self.gn_2(f_1)
+
+ f_2 = self.f_2(F.silu(t_2 + (t_1 * gn_2)))
+
+ return f_2 + F.upsample_nearest(x_skip, scale_factor=2)
+
+
+class ConvUNetVAE(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.embed_image = ImageEmbedding()
+ self.embed_time = TimestepEmbedding_()
+
+ down_0 = nn.ModuleList(
+ [
+ ConvResblock(320, 320),
+ ConvResblock(320, 320),
+ ConvResblock(320, 320),
+ Downsample(320),
+ ]
+ )
+ down_1 = nn.ModuleList(
+ [
+ ConvResblock(320, 640),
+ ConvResblock(640, 640),
+ ConvResblock(640, 640),
+ Downsample(640),
+ ]
+ )
+ down_2 = nn.ModuleList(
+ [
+ ConvResblock(640, 1024),
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ Downsample(1024),
+ ]
+ )
+ down_3 = nn.ModuleList(
+ [
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ ]
+ )
+ self.down = nn.ModuleList(
+ [
+ down_0,
+ down_1,
+ down_2,
+ down_3,
+ ]
+ )
+
+ self.mid = nn.ModuleList(
+ [
+ ConvResblock(1024, 1024),
+ ConvResblock(1024, 1024),
+ ]
+ )
+
+ up_3 = nn.ModuleList(
+ [
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ Upsample(1024),
+ ]
+ )
+ up_2 = nn.ModuleList(
+ [
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 * 2, 1024),
+ ConvResblock(1024 + 640, 1024),
+ Upsample(1024),
+ ]
+ )
+ up_1 = nn.ModuleList(
+ [
+ ConvResblock(1024 + 640, 640),
+ ConvResblock(640 * 2, 640),
+ ConvResblock(640 * 2, 640),
+ ConvResblock(320 + 640, 640),
+ Upsample(640),
+ ]
+ )
+ up_0 = nn.ModuleList(
+ [
+ ConvResblock(320 + 640, 320),
+ ConvResblock(320 * 2, 320),
+ ConvResblock(320 * 2, 320),
+ ConvResblock(320 * 2, 320),
+ ]
+ )
+ self.up = nn.ModuleList(
+ [
+ up_0,
+ up_1,
+ up_2,
+ up_3,
+ ]
+ )
+
+ self.output = ImageUnembedding()
+
+ def forward(self, x, t, features) -> torch.Tensor:
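+ # Once the hand-rolled blocks have been swapped for diffusers blocks (the conversion
+ # section below sets `self.converted = True`), forward() routes through the new modules;
+ # both paths are checked to produce identical outputs.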
+ converted = hasattr(self, "converted") and self.converted
+
+ x = torch.cat([x, F.upsample_nearest(features, scale_factor=8)], dim=1)
+
+ if converted:
+ t = self.time_embedding(self.time_proj(t))
+ else:
+ t = self.embed_time(t)
+
+ x = self.embed_image(x)
+
+ skips = [x]
+ for i, down in enumerate(self.down):
+ if converted and i in [0, 1, 2, 3]:
+ x, skips_ = down(x, t)
+ for skip in skips_:
+ skips.append(skip)
+ else:
+ for block in down:
+ x = block(x, t)
+ skips.append(x)
+ print(x.float().abs().sum())
+
+ if converted:
+ x = self.mid(x, t)
+ else:
+ for i in range(2):
+ x = self.mid[i](x, t)
+ print(x.float().abs().sum())
+
+ for i, up in enumerate(self.up[::-1]):
+ if converted and i in [0, 1, 2, 3]:
+ skip_4 = skips.pop()
+ skip_3 = skips.pop()
+ skip_2 = skips.pop()
+ skip_1 = skips.pop()
+ skips_ = (skip_1, skip_2, skip_3, skip_4)
+ x = up(x, skips_, t)
+ else:
+ for block in up:
+ if isinstance(block, ConvResblock):
+ x = torch.concat([x, skips.pop()], dim=1)
+ x = block(x, t)
+
+ return self.output(x)
+
+
+def rename_state_dict_key(k):
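+ # Translate flat keys from the original checkpoint into this module's nested names,
+ # e.g. "down_0_conv_1.f_1.w" -> "down.0.1.f_1.weight".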
+ k = k.replace("blocks.", "")
+ for i in range(5):
+ k = k.replace(f"down_{i}_", f"down.{i}.")
+ k = k.replace(f"conv_{i}.", f"{i}.")
+ k = k.replace(f"up_{i}_", f"up.{i}.")
+ k = k.replace(f"mid_{i}", f"mid.{i}")
+ k = k.replace("upsamp.", "4.")
+ k = k.replace("downsamp.", "3.")
+ k = k.replace("f_t.w", "f_t.weight").replace("f_t.b", "f_t.bias")
+ k = k.replace("f_1.w", "f_1.weight").replace("f_1.b", "f_1.bias")
+ k = k.replace("f_2.w", "f_2.weight").replace("f_2.b", "f_2.bias")
+ k = k.replace("f_s.w", "f_s.weight").replace("f_s.b", "f_s.bias")
+ k = k.replace("f.w", "f.weight").replace("f.b", "f.bias")
+ k = k.replace("gn_1.g", "gn_1.weight").replace("gn_1.b", "gn_1.bias")
+ k = k.replace("gn_2.g", "gn_2.weight").replace("gn_2.b", "gn_2.bias")
+ k = k.replace("gn.g", "gn.weight").replace("gn.b", "gn.bias")
+ return k
+
+
+def rename_state_dict(sd, embedding):
+ sd = {rename_state_dict_key(k): v for k, v in sd.items()}
+ sd["embed_time.emb.weight"] = embedding["weight"]
+ return sd
+
+
+# encode with stable diffusion vae
+pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+pipe.vae.cuda()
+
+# construct original decoder with jitted model
+decoder_consistency = ConsistencyDecoder(device="cuda:0")
+
+# construct UNet code, overwrite the decoder with conv_unet_vae
+model = ConvUNetVAE()
+model.load_state_dict(
+ rename_state_dict(
+ stl("consistency_decoder.safetensors"),
+ stl("embedding.safetensors"),
+ )
+)
+model = model.cuda()
+
+decoder_consistency.ckpt = model
+
+image = load_image(args.test_image, size=(256, 256), center_crop=True)
+latent = pipe.vae.encode(image.half().cuda()).latent_dist.sample()
+
+# decode with gan
+sample_gan = pipe.vae.decode(latent).sample.detach()
+save_image(sample_gan, "gan.png")
+
+# decode with conv_unet_vae
+sample_consistency_orig = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
+save_image(sample_consistency_orig, "con_orig.png")
+
+
+########### conversion
+
+print("CONVERSION")
+
+print("DOWN BLOCK ONE")
+
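+# Each ConvResblock maps onto a diffusers resnet: gn_1/f_1 -> norm1/conv1, f_t -> time_emb_proj,
+# gn_2/f_2 -> norm2/conv2, f_s -> conv_shortcut; a trailing Downsample/Upsample becomes
+# downsamplers.0 / upsamplers.0. The same remapping repeats for every block below.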
+block_one_sd_orig = model.down[0].state_dict()
+block_one_sd_new = {}
+
+for i in range(3):
+ block_one_sd_new[f"resnets.{i}.norm1.weight"] = block_one_sd_orig.pop(f"{i}.gn_1.weight")
+ block_one_sd_new[f"resnets.{i}.norm1.bias"] = block_one_sd_orig.pop(f"{i}.gn_1.bias")
+ block_one_sd_new[f"resnets.{i}.conv1.weight"] = block_one_sd_orig.pop(f"{i}.f_1.weight")
+ block_one_sd_new[f"resnets.{i}.conv1.bias"] = block_one_sd_orig.pop(f"{i}.f_1.bias")
+ block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_one_sd_orig.pop(f"{i}.f_t.weight")
+ block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_one_sd_orig.pop(f"{i}.f_t.bias")
+ block_one_sd_new[f"resnets.{i}.norm2.weight"] = block_one_sd_orig.pop(f"{i}.gn_2.weight")
+ block_one_sd_new[f"resnets.{i}.norm2.bias"] = block_one_sd_orig.pop(f"{i}.gn_2.bias")
+ block_one_sd_new[f"resnets.{i}.conv2.weight"] = block_one_sd_orig.pop(f"{i}.f_2.weight")
+ block_one_sd_new[f"resnets.{i}.conv2.bias"] = block_one_sd_orig.pop(f"{i}.f_2.bias")
+
+block_one_sd_new["downsamplers.0.norm1.weight"] = block_one_sd_orig.pop("3.gn_1.weight")
+block_one_sd_new["downsamplers.0.norm1.bias"] = block_one_sd_orig.pop("3.gn_1.bias")
+block_one_sd_new["downsamplers.0.conv1.weight"] = block_one_sd_orig.pop("3.f_1.weight")
+block_one_sd_new["downsamplers.0.conv1.bias"] = block_one_sd_orig.pop("3.f_1.bias")
+block_one_sd_new["downsamplers.0.time_emb_proj.weight"] = block_one_sd_orig.pop("3.f_t.weight")
+block_one_sd_new["downsamplers.0.time_emb_proj.bias"] = block_one_sd_orig.pop("3.f_t.bias")
+block_one_sd_new["downsamplers.0.norm2.weight"] = block_one_sd_orig.pop("3.gn_2.weight")
+block_one_sd_new["downsamplers.0.norm2.bias"] = block_one_sd_orig.pop("3.gn_2.bias")
+block_one_sd_new["downsamplers.0.conv2.weight"] = block_one_sd_orig.pop("3.f_2.weight")
+block_one_sd_new["downsamplers.0.conv2.bias"] = block_one_sd_orig.pop("3.f_2.bias")
+
+assert len(block_one_sd_orig) == 0
+
+block_one = ResnetDownsampleBlock2D(
+ in_channels=320,
+ out_channels=320,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_one.load_state_dict(block_one_sd_new)
+
+print("DOWN BLOCK TWO")
+
+block_two_sd_orig = model.down[1].state_dict()
+block_two_sd_new = {}
+
+for i in range(3):
+ block_two_sd_new[f"resnets.{i}.norm1.weight"] = block_two_sd_orig.pop(f"{i}.gn_1.weight")
+ block_two_sd_new[f"resnets.{i}.norm1.bias"] = block_two_sd_orig.pop(f"{i}.gn_1.bias")
+ block_two_sd_new[f"resnets.{i}.conv1.weight"] = block_two_sd_orig.pop(f"{i}.f_1.weight")
+ block_two_sd_new[f"resnets.{i}.conv1.bias"] = block_two_sd_orig.pop(f"{i}.f_1.bias")
+ block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_two_sd_orig.pop(f"{i}.f_t.weight")
+ block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_two_sd_orig.pop(f"{i}.f_t.bias")
+ block_two_sd_new[f"resnets.{i}.norm2.weight"] = block_two_sd_orig.pop(f"{i}.gn_2.weight")
+ block_two_sd_new[f"resnets.{i}.norm2.bias"] = block_two_sd_orig.pop(f"{i}.gn_2.bias")
+ block_two_sd_new[f"resnets.{i}.conv2.weight"] = block_two_sd_orig.pop(f"{i}.f_2.weight")
+ block_two_sd_new[f"resnets.{i}.conv2.bias"] = block_two_sd_orig.pop(f"{i}.f_2.bias")
+
+ if i == 0:
+ block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_two_sd_orig.pop(f"{i}.f_s.weight")
+ block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_two_sd_orig.pop(f"{i}.f_s.bias")
+
+block_two_sd_new["downsamplers.0.norm1.weight"] = block_two_sd_orig.pop("3.gn_1.weight")
+block_two_sd_new["downsamplers.0.norm1.bias"] = block_two_sd_orig.pop("3.gn_1.bias")
+block_two_sd_new["downsamplers.0.conv1.weight"] = block_two_sd_orig.pop("3.f_1.weight")
+block_two_sd_new["downsamplers.0.conv1.bias"] = block_two_sd_orig.pop("3.f_1.bias")
+block_two_sd_new["downsamplers.0.time_emb_proj.weight"] = block_two_sd_orig.pop("3.f_t.weight")
+block_two_sd_new["downsamplers.0.time_emb_proj.bias"] = block_two_sd_orig.pop("3.f_t.bias")
+block_two_sd_new["downsamplers.0.norm2.weight"] = block_two_sd_orig.pop("3.gn_2.weight")
+block_two_sd_new["downsamplers.0.norm2.bias"] = block_two_sd_orig.pop("3.gn_2.bias")
+block_two_sd_new["downsamplers.0.conv2.weight"] = block_two_sd_orig.pop("3.f_2.weight")
+block_two_sd_new["downsamplers.0.conv2.bias"] = block_two_sd_orig.pop("3.f_2.bias")
+
+assert len(block_two_sd_orig) == 0
+
+block_two = ResnetDownsampleBlock2D(
+ in_channels=320,
+ out_channels=640,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_two.load_state_dict(block_two_sd_new)
+
+print("DOWN BLOCK THREE")
+
+block_three_sd_orig = model.down[2].state_dict()
+block_three_sd_new = {}
+
+for i in range(3):
+ block_three_sd_new[f"resnets.{i}.norm1.weight"] = block_three_sd_orig.pop(f"{i}.gn_1.weight")
+ block_three_sd_new[f"resnets.{i}.norm1.bias"] = block_three_sd_orig.pop(f"{i}.gn_1.bias")
+ block_three_sd_new[f"resnets.{i}.conv1.weight"] = block_three_sd_orig.pop(f"{i}.f_1.weight")
+ block_three_sd_new[f"resnets.{i}.conv1.bias"] = block_three_sd_orig.pop(f"{i}.f_1.bias")
+ block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_three_sd_orig.pop(f"{i}.f_t.weight")
+ block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_three_sd_orig.pop(f"{i}.f_t.bias")
+ block_three_sd_new[f"resnets.{i}.norm2.weight"] = block_three_sd_orig.pop(f"{i}.gn_2.weight")
+ block_three_sd_new[f"resnets.{i}.norm2.bias"] = block_three_sd_orig.pop(f"{i}.gn_2.bias")
+ block_three_sd_new[f"resnets.{i}.conv2.weight"] = block_three_sd_orig.pop(f"{i}.f_2.weight")
+ block_three_sd_new[f"resnets.{i}.conv2.bias"] = block_three_sd_orig.pop(f"{i}.f_2.bias")
+
+ if i == 0:
+ block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = block_three_sd_orig.pop(f"{i}.f_s.weight")
+ block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = block_three_sd_orig.pop(f"{i}.f_s.bias")
+
+block_three_sd_new["downsamplers.0.norm1.weight"] = block_three_sd_orig.pop("3.gn_1.weight")
+block_three_sd_new["downsamplers.0.norm1.bias"] = block_three_sd_orig.pop("3.gn_1.bias")
+block_three_sd_new["downsamplers.0.conv1.weight"] = block_three_sd_orig.pop("3.f_1.weight")
+block_three_sd_new["downsamplers.0.conv1.bias"] = block_three_sd_orig.pop("3.f_1.bias")
+block_three_sd_new["downsamplers.0.time_emb_proj.weight"] = block_three_sd_orig.pop("3.f_t.weight")
+block_three_sd_new["downsamplers.0.time_emb_proj.bias"] = block_three_sd_orig.pop("3.f_t.bias")
+block_three_sd_new["downsamplers.0.norm2.weight"] = block_three_sd_orig.pop("3.gn_2.weight")
+block_three_sd_new["downsamplers.0.norm2.bias"] = block_three_sd_orig.pop("3.gn_2.bias")
+block_three_sd_new["downsamplers.0.conv2.weight"] = block_three_sd_orig.pop("3.f_2.weight")
+block_three_sd_new["downsamplers.0.conv2.bias"] = block_three_sd_orig.pop("3.f_2.bias")
+
+assert len(block_three_sd_orig) == 0
+
+block_three = ResnetDownsampleBlock2D(
+ in_channels=640,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_three.load_state_dict(block_three_sd_new)
+
+print("DOWN BLOCK FOUR")
+
+block_four_sd_orig = model.down[3].state_dict()
+block_four_sd_new = {}
+
+for i in range(3):
+ block_four_sd_new[f"resnets.{i}.norm1.weight"] = block_four_sd_orig.pop(f"{i}.gn_1.weight")
+ block_four_sd_new[f"resnets.{i}.norm1.bias"] = block_four_sd_orig.pop(f"{i}.gn_1.bias")
+ block_four_sd_new[f"resnets.{i}.conv1.weight"] = block_four_sd_orig.pop(f"{i}.f_1.weight")
+ block_four_sd_new[f"resnets.{i}.conv1.bias"] = block_four_sd_orig.pop(f"{i}.f_1.bias")
+ block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = block_four_sd_orig.pop(f"{i}.f_t.weight")
+ block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = block_four_sd_orig.pop(f"{i}.f_t.bias")
+ block_four_sd_new[f"resnets.{i}.norm2.weight"] = block_four_sd_orig.pop(f"{i}.gn_2.weight")
+ block_four_sd_new[f"resnets.{i}.norm2.bias"] = block_four_sd_orig.pop(f"{i}.gn_2.bias")
+ block_four_sd_new[f"resnets.{i}.conv2.weight"] = block_four_sd_orig.pop(f"{i}.f_2.weight")
+ block_four_sd_new[f"resnets.{i}.conv2.bias"] = block_four_sd_orig.pop(f"{i}.f_2.bias")
+
+assert len(block_four_sd_orig) == 0
+
+block_four = ResnetDownsampleBlock2D(
+ in_channels=1024,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=3,
+ add_downsample=False,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+block_four.load_state_dict(block_four_sd_new)
+
+
+print("MID BLOCK 1")
+
+mid_block_one_sd_orig = model.mid.state_dict()
+mid_block_one_sd_new = {}
+
+for i in range(2):
+ mid_block_one_sd_new[f"resnets.{i}.norm1.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.weight")
+ mid_block_one_sd_new[f"resnets.{i}.norm1.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_1.bias")
+ mid_block_one_sd_new[f"resnets.{i}.conv1.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_1.weight")
+ mid_block_one_sd_new[f"resnets.{i}.conv1.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_1.bias")
+ mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_t.weight")
+ mid_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_t.bias")
+ mid_block_one_sd_new[f"resnets.{i}.norm2.weight"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.weight")
+ mid_block_one_sd_new[f"resnets.{i}.norm2.bias"] = mid_block_one_sd_orig.pop(f"{i}.gn_2.bias")
+ mid_block_one_sd_new[f"resnets.{i}.conv2.weight"] = mid_block_one_sd_orig.pop(f"{i}.f_2.weight")
+ mid_block_one_sd_new[f"resnets.{i}.conv2.bias"] = mid_block_one_sd_orig.pop(f"{i}.f_2.bias")
+
+assert len(mid_block_one_sd_orig) == 0
+
+mid_block_one = UNetMidBlock2D(
+ in_channels=1024,
+ temb_channels=1280,
+ num_layers=1,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+ add_attention=False,
+)
+
+mid_block_one.load_state_dict(mid_block_one_sd_new)
+
+print("UP BLOCK ONE")
+
+up_block_one_sd_orig = model.up[-1].state_dict()
+up_block_one_sd_new = {}
+
+for i in range(4):
+ up_block_one_sd_new[f"resnets.{i}.norm1.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_one_sd_new[f"resnets.{i}.norm1.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_one_sd_new[f"resnets.{i}.conv1.weight"] = up_block_one_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_one_sd_new[f"resnets.{i}.conv1.bias"] = up_block_one_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_one_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_one_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_one_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_one_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_one_sd_new[f"resnets.{i}.norm2.weight"] = up_block_one_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_one_sd_new[f"resnets.{i}.norm2.bias"] = up_block_one_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_one_sd_new[f"resnets.{i}.conv2.weight"] = up_block_one_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_one_sd_new[f"resnets.{i}.conv2.bias"] = up_block_one_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_one_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_one_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_one_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_one_sd_orig.pop(f"{i}.f_s.bias")
+
+up_block_one_sd_new["upsamplers.0.norm1.weight"] = up_block_one_sd_orig.pop("4.gn_1.weight")
+up_block_one_sd_new["upsamplers.0.norm1.bias"] = up_block_one_sd_orig.pop("4.gn_1.bias")
+up_block_one_sd_new["upsamplers.0.conv1.weight"] = up_block_one_sd_orig.pop("4.f_1.weight")
+up_block_one_sd_new["upsamplers.0.conv1.bias"] = up_block_one_sd_orig.pop("4.f_1.bias")
+up_block_one_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_one_sd_orig.pop("4.f_t.weight")
+up_block_one_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_one_sd_orig.pop("4.f_t.bias")
+up_block_one_sd_new["upsamplers.0.norm2.weight"] = up_block_one_sd_orig.pop("4.gn_2.weight")
+up_block_one_sd_new["upsamplers.0.norm2.bias"] = up_block_one_sd_orig.pop("4.gn_2.bias")
+up_block_one_sd_new["upsamplers.0.conv2.weight"] = up_block_one_sd_orig.pop("4.f_2.weight")
+up_block_one_sd_new["upsamplers.0.conv2.bias"] = up_block_one_sd_orig.pop("4.f_2.bias")
+
+assert len(up_block_one_sd_orig) == 0
+
+up_block_one = ResnetUpsampleBlock2D(
+ in_channels=1024,
+ prev_output_channel=1024,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_one.load_state_dict(up_block_one_sd_new)
+
+print("UP BLOCK TWO")
+
+up_block_two_sd_orig = model.up[-2].state_dict()
+up_block_two_sd_new = {}
+
+for i in range(4):
+ up_block_two_sd_new[f"resnets.{i}.norm1.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_two_sd_new[f"resnets.{i}.norm1.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_two_sd_new[f"resnets.{i}.conv1.weight"] = up_block_two_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_two_sd_new[f"resnets.{i}.conv1.bias"] = up_block_two_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_two_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_two_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_two_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_two_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_two_sd_new[f"resnets.{i}.norm2.weight"] = up_block_two_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_two_sd_new[f"resnets.{i}.norm2.bias"] = up_block_two_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_two_sd_new[f"resnets.{i}.conv2.weight"] = up_block_two_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_two_sd_new[f"resnets.{i}.conv2.bias"] = up_block_two_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_two_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_two_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_two_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_two_sd_orig.pop(f"{i}.f_s.bias")
+
+up_block_two_sd_new["upsamplers.0.norm1.weight"] = up_block_two_sd_orig.pop("4.gn_1.weight")
+up_block_two_sd_new["upsamplers.0.norm1.bias"] = up_block_two_sd_orig.pop("4.gn_1.bias")
+up_block_two_sd_new["upsamplers.0.conv1.weight"] = up_block_two_sd_orig.pop("4.f_1.weight")
+up_block_two_sd_new["upsamplers.0.conv1.bias"] = up_block_two_sd_orig.pop("4.f_1.bias")
+up_block_two_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_two_sd_orig.pop("4.f_t.weight")
+up_block_two_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_two_sd_orig.pop("4.f_t.bias")
+up_block_two_sd_new["upsamplers.0.norm2.weight"] = up_block_two_sd_orig.pop("4.gn_2.weight")
+up_block_two_sd_new["upsamplers.0.norm2.bias"] = up_block_two_sd_orig.pop("4.gn_2.bias")
+up_block_two_sd_new["upsamplers.0.conv2.weight"] = up_block_two_sd_orig.pop("4.f_2.weight")
+up_block_two_sd_new["upsamplers.0.conv2.bias"] = up_block_two_sd_orig.pop("4.f_2.bias")
+
+assert len(up_block_two_sd_orig) == 0
+
+up_block_two = ResnetUpsampleBlock2D(
+ in_channels=640,
+ prev_output_channel=1024,
+ out_channels=1024,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_two.load_state_dict(up_block_two_sd_new)
+
+print("UP BLOCK THREE")
+
+up_block_three_sd_orig = model.up[-3].state_dict()
+up_block_three_sd_new = {}
+
+for i in range(4):
+ up_block_three_sd_new[f"resnets.{i}.norm1.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_three_sd_new[f"resnets.{i}.norm1.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_three_sd_new[f"resnets.{i}.conv1.weight"] = up_block_three_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_three_sd_new[f"resnets.{i}.conv1.bias"] = up_block_three_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_three_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_three_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_three_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_three_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_three_sd_new[f"resnets.{i}.norm2.weight"] = up_block_three_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_three_sd_new[f"resnets.{i}.norm2.bias"] = up_block_three_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_three_sd_new[f"resnets.{i}.conv2.weight"] = up_block_three_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_three_sd_new[f"resnets.{i}.conv2.bias"] = up_block_three_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_three_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_three_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_three_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_three_sd_orig.pop(f"{i}.f_s.bias")
+
+up_block_three_sd_new["upsamplers.0.norm1.weight"] = up_block_three_sd_orig.pop("4.gn_1.weight")
+up_block_three_sd_new["upsamplers.0.norm1.bias"] = up_block_three_sd_orig.pop("4.gn_1.bias")
+up_block_three_sd_new["upsamplers.0.conv1.weight"] = up_block_three_sd_orig.pop("4.f_1.weight")
+up_block_three_sd_new["upsamplers.0.conv1.bias"] = up_block_three_sd_orig.pop("4.f_1.bias")
+up_block_three_sd_new["upsamplers.0.time_emb_proj.weight"] = up_block_three_sd_orig.pop("4.f_t.weight")
+up_block_three_sd_new["upsamplers.0.time_emb_proj.bias"] = up_block_three_sd_orig.pop("4.f_t.bias")
+up_block_three_sd_new["upsamplers.0.norm2.weight"] = up_block_three_sd_orig.pop("4.gn_2.weight")
+up_block_three_sd_new["upsamplers.0.norm2.bias"] = up_block_three_sd_orig.pop("4.gn_2.bias")
+up_block_three_sd_new["upsamplers.0.conv2.weight"] = up_block_three_sd_orig.pop("4.f_2.weight")
+up_block_three_sd_new["upsamplers.0.conv2.bias"] = up_block_three_sd_orig.pop("4.f_2.bias")
+
+assert len(up_block_three_sd_orig) == 0
+
+up_block_three = ResnetUpsampleBlock2D(
+ in_channels=320,
+ prev_output_channel=1024,
+ out_channels=640,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=True,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_three.load_state_dict(up_block_three_sd_new)
+
+print("UP BLOCK FOUR")
+
+up_block_four_sd_orig = model.up[-4].state_dict()
+up_block_four_sd_new = {}
+
+for i in range(4):
+ up_block_four_sd_new[f"resnets.{i}.norm1.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_1.weight")
+ up_block_four_sd_new[f"resnets.{i}.norm1.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_1.bias")
+ up_block_four_sd_new[f"resnets.{i}.conv1.weight"] = up_block_four_sd_orig.pop(f"{i}.f_1.weight")
+ up_block_four_sd_new[f"resnets.{i}.conv1.bias"] = up_block_four_sd_orig.pop(f"{i}.f_1.bias")
+ up_block_four_sd_new[f"resnets.{i}.time_emb_proj.weight"] = up_block_four_sd_orig.pop(f"{i}.f_t.weight")
+ up_block_four_sd_new[f"resnets.{i}.time_emb_proj.bias"] = up_block_four_sd_orig.pop(f"{i}.f_t.bias")
+ up_block_four_sd_new[f"resnets.{i}.norm2.weight"] = up_block_four_sd_orig.pop(f"{i}.gn_2.weight")
+ up_block_four_sd_new[f"resnets.{i}.norm2.bias"] = up_block_four_sd_orig.pop(f"{i}.gn_2.bias")
+ up_block_four_sd_new[f"resnets.{i}.conv2.weight"] = up_block_four_sd_orig.pop(f"{i}.f_2.weight")
+ up_block_four_sd_new[f"resnets.{i}.conv2.bias"] = up_block_four_sd_orig.pop(f"{i}.f_2.bias")
+ up_block_four_sd_new[f"resnets.{i}.conv_shortcut.weight"] = up_block_four_sd_orig.pop(f"{i}.f_s.weight")
+ up_block_four_sd_new[f"resnets.{i}.conv_shortcut.bias"] = up_block_four_sd_orig.pop(f"{i}.f_s.bias")
+
+assert len(up_block_four_sd_orig) == 0
+
+up_block_four = ResnetUpsampleBlock2D(
+ in_channels=320,
+ prev_output_channel=640,
+ out_channels=320,
+ temb_channels=1280,
+ num_layers=4,
+ add_upsample=False,
+ resnet_time_scale_shift="scale_shift",
+ resnet_eps=1e-5,
+)
+
+up_block_four.load_state_dict(up_block_four_sd_new)
+
+print("initial projection (conv_in)")
+
+conv_in_sd_orig = model.embed_image.state_dict()
+conv_in_sd_new = {}
+
+conv_in_sd_new["weight"] = conv_in_sd_orig.pop("f.weight")
+conv_in_sd_new["bias"] = conv_in_sd_orig.pop("f.bias")
+
+assert len(conv_in_sd_orig) == 0
+
+block_out_channels = [320, 640, 1024, 1024]
+
+in_channels = 7
+conv_in_kernel = 3
+conv_in_padding = (conv_in_kernel - 1) // 2
+conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding)
+
+conv_in.load_state_dict(conv_in_sd_new)
+
+print("out projection (conv_out) (conv_norm_out)")
+out_channels = 6
+norm_num_groups = 32
+norm_eps = 1e-5
+act_fn = "silu"
+conv_out_kernel = 3
+conv_out_padding = (conv_out_kernel - 1) // 2
+conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+# uses torch.functional in orig
+# conv_act = get_activation(act_fn)
+conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding)
+
+conv_norm_out.load_state_dict(model.output.gn.state_dict())
+conv_out.load_state_dict(model.output.f.state_dict())
+
+print("timestep projection (time_proj) (time_embedding)")
+
+f1_sd = model.embed_time.f_1.state_dict()
+f2_sd = model.embed_time.f_2.state_dict()
+
+time_embedding_sd = {
+ "linear_1.weight": f1_sd.pop("weight"),
+ "linear_1.bias": f1_sd.pop("bias"),
+ "linear_2.weight": f2_sd.pop("weight"),
+ "linear_2.bias": f2_sd.pop("bias"),
+}
+
+assert len(f1_sd) == 0
+assert len(f2_sd) == 0
+
+time_embedding_type = "learned"
+num_train_timesteps = 1024
+time_embedding_dim = 1280
+
+time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
+timestep_input_dim = block_out_channels[0]
+
+time_embedding = TimestepEmbedding(timestep_input_dim, time_embedding_dim)
+
+time_proj.load_state_dict(model.embed_time.emb.state_dict())
+time_embedding.load_state_dict(time_embedding_sd)
+
+print("CONVERT")
+
+time_embedding.to("cuda")
+time_proj.to("cuda")
+conv_in.to("cuda")
+
+block_one.to("cuda")
+block_two.to("cuda")
+block_three.to("cuda")
+block_four.to("cuda")
+
+mid_block_one.to("cuda")
+
+up_block_one.to("cuda")
+up_block_two.to("cuda")
+up_block_three.to("cuda")
+up_block_four.to("cuda")
+
+conv_norm_out.to("cuda")
+conv_out.to("cuda")
+
+model.time_proj = time_proj
+model.time_embedding = time_embedding
+model.embed_image = conv_in
+
+model.down[0] = block_one
+model.down[1] = block_two
+model.down[2] = block_three
+model.down[3] = block_four
+
+model.mid = mid_block_one
+
+model.up[-1] = up_block_one
+model.up[-2] = up_block_two
+model.up[-3] = up_block_three
+model.up[-4] = up_block_four
+
+model.output.gn = conv_norm_out
+model.output.f = conv_out
+
+model.converted = True
+
+sample_consistency_new = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
+save_image(sample_consistency_new, "con_new.png")
+
+assert (sample_consistency_orig == sample_consistency_new).all()
+
+print("making unet")
+
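+# Assemble a single UNet2DModel with the same topology as the hand-built blocks, then copy
+# each converted state dict into it under the standard diffusers prefixes.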
+unet = UNet2DModel(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ down_block_types=(
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ ),
+ up_block_types=(
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ ),
+ block_out_channels=block_out_channels,
+ layers_per_block=3,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ resnet_time_scale_shift="scale_shift",
+ time_embedding_type="learned",
+ num_train_timesteps=num_train_timesteps,
+ add_attention=False,
+)
+
+unet_state_dict = {}
+
+
+def add_state_dict(prefix, mod):
+ for k, v in mod.state_dict().items():
+ unet_state_dict[f"{prefix}.{k}"] = v
+
+
+add_state_dict("conv_in", conv_in)
+add_state_dict("time_proj", time_proj)
+add_state_dict("time_embedding", time_embedding)
+add_state_dict("down_blocks.0", block_one)
+add_state_dict("down_blocks.1", block_two)
+add_state_dict("down_blocks.2", block_three)
+add_state_dict("down_blocks.3", block_four)
+add_state_dict("mid_block", mid_block_one)
+add_state_dict("up_blocks.0", up_block_one)
+add_state_dict("up_blocks.1", up_block_two)
+add_state_dict("up_blocks.2", up_block_three)
+add_state_dict("up_blocks.3", up_block_four)
+add_state_dict("conv_norm_out", conv_norm_out)
+add_state_dict("conv_out", conv_out)
+
+unet.load_state_dict(unet_state_dict)
+
+print("running with diffusers unet")
+
+unet.to("cuda")
+
+decoder_consistency.ckpt = unet
+
+sample_consistency_new_2 = decoder_consistency(latent, generator=torch.Generator("cpu").manual_seed(0))
+save_image(sample_consistency_new_2, "con_new_2.png")
+
+assert (sample_consistency_orig == sample_consistency_new_2).all()
+
+print("running with diffusers model")
+
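+# Monkey-patch Encoder.__init__ to record the kwargs it was constructed with, so they can be
+# reused as `encoder_args` when building the ConsistencyDecoderVAE below.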
+Encoder.old_constructor = Encoder.__init__
+
+
+def new_constructor(self, **kwargs):
+ self.old_constructor(**kwargs)
+ self.constructor_arguments = kwargs
+
+
+Encoder.__init__ = new_constructor
+
+
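+# The final ConsistencyDecoderVAE reuses the encoder and quant_conv of the original Stable
+# Diffusion VAE and plugs the converted consistency UNet in as its decoder.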
+vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
+consistency_vae = ConsistencyDecoderVAE(
+ encoder_args=vae.encoder.constructor_arguments,
+ decoder_args=unet.config,
+ scaling_factor=vae.config.scaling_factor,
+ block_out_channels=vae.config.block_out_channels,
+ latent_channels=vae.config.latent_channels,
+)
+consistency_vae.encoder.load_state_dict(vae.encoder.state_dict())
+consistency_vae.quant_conv.load_state_dict(vae.quant_conv.state_dict())
+consistency_vae.decoder_unet.load_state_dict(unet.state_dict())
+
+consistency_vae.to(dtype=torch.float16, device="cuda")
+
+sample_consistency_new_3 = consistency_vae.decode(
+ 0.18215 * latent, generator=torch.Generator("cpu").manual_seed(0)
+).sample
+
+print("max difference")
+print((sample_consistency_orig - sample_consistency_new_3).abs().max())
+print("total difference")
+print((sample_consistency_orig - sample_consistency_new_3).abs().sum())
+# assert (sample_consistency_orig == sample_consistency_new_3).all()
+
+print("running with diffusers pipeline")
+
+pipe = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16
+)
+pipe.to("cuda")
+
+pipe("horse", generator=torch.Generator("cpu").manual_seed(0)).images[0].save("horse.png")
+
+
+if args.save_pretrained is not None:
+ consistency_vae.save_pretrained(args.save_pretrained)
diff --git a/setup.py b/setup.py
index c2c8e75c24ae..f4b14aee49e5 100644
--- a/setup.py
+++ b/setup.py
@@ -244,7 +244,7 @@ def run(self):
setup(
name="diffusers",
- version="0.23.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ version="0.24.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 4291e911ac74..787e3b1c29e7 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.23.0.dev0"
+__version__ = "0.24.0.dev0"
from typing import TYPE_CHECKING
@@ -77,6 +77,7 @@
"AsymmetricAutoencoderKL",
"AutoencoderKL",
"AutoencoderTiny",
+ "ConsistencyDecoderVAE",
"ControlNetModel",
"ModelMixin",
"MotionAdapter",
@@ -443,6 +444,7 @@
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderTiny,
+ ConsistencyDecoderVAE,
ControlNetModel,
ModelMixin,
MotionAdapter,
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 87e0e164026f..4590c2452b88 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -1411,6 +1411,11 @@ def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_ext
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
)
+ if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
+ elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
+ targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
+
if len(targeted_files) > 1:
raise ValueError(
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
@@ -2390,7 +2395,7 @@ def unfuse_text_encoder_lora(text_encoder):
def set_adapters_for_text_encoder(
self,
adapter_names: Union[List[str], str],
- text_encoder: Optional[PreTrainedModel] = None,
+ text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
text_encoder_weights: List[float] = None,
):
"""
@@ -2429,7 +2434,7 @@ def process_weights(adapter_names, weights):
)
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
- def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
+ def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
"""
Disables the LoRA layers for the text encoder.
@@ -2446,7 +2451,7 @@ def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel]
raise ValueError("Text Encoder not found.")
set_adapter_layers(text_encoder, enabled=False)
- def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
+ def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
"""
Enables the LoRA layers for the text encoder.
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index f807353312d1..d45f56d43c32 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -24,6 +24,7 @@
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
+ _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["modeling_utils"] = ["ModelMixin"]
@@ -50,6 +51,7 @@
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_tiny import AutoencoderTiny
+ from .consistency_decoder_vae import ConsistencyDecoderVAE
from .controlnet import ControlNetModel
from .dual_transformer_2d import DualTransformer2DModel
from .modeling_utils import ModelMixin
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 9773cafc6947..0c4c5de6e31a 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -287,7 +287,7 @@ def forward(
else:
raise ValueError("Incorrect norm")
- if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index efed305a0e96..1234dbd2d5ce 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -378,7 +378,7 @@ def set_processor(self, processor: "AttnProcessor", _remove_lora: bool = False)
_remove_lora (`bool`, *optional*, defaults to `False`):
Set to `True` to remove LoRA layers from the model.
"""
- if hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
+ if not USE_PEFT_BACKEND and hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
deprecate(
"set_processor to offload LoRA",
"0.26.0",
@@ -879,6 +879,9 @@ def __call__(
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
@@ -891,17 +894,17 @@ def __call__(
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- query = attn.to_q(hidden_states, scale=scale)
+ query = attn.to_q(hidden_states, *args)
query = attn.head_to_batch_dim(query)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
if not attn.only_cross_attention:
- key = attn.to_k(hidden_states, scale=scale)
- value = attn.to_v(hidden_states, scale=scale)
+ key = attn.to_k(hidden_states, *args)
+ value = attn.to_v(hidden_states, *args)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -915,7 +918,7 @@ def __call__(
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
- hidden_states = attn.to_out[0](hidden_states, scale=scale)
+ hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -946,6 +949,9 @@ def __call__(
scale: float = 1.0,
) -> torch.Tensor:
residual = hidden_states
+
+ args = () if USE_PEFT_BACKEND else (scale,)
+
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
batch_size, sequence_length, _ = hidden_states.shape
@@ -958,7 +964,7 @@ def __call__(
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
- query = attn.to_q(hidden_states, scale=scale)
+ query = attn.to_q(hidden_states, *args)
query = attn.head_to_batch_dim(query, out_dim=4)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -967,8 +973,8 @@ def __call__(
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
if not attn.only_cross_attention:
- key = attn.to_k(hidden_states, scale=scale)
- value = attn.to_v(hidden_states, scale=scale)
+ key = attn.to_k(hidden_states, *args)
+ value = attn.to_v(hidden_states, *args)
key = attn.head_to_batch_dim(key, out_dim=4)
value = attn.head_to_batch_dim(value, out_dim=4)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -985,7 +991,7 @@ def __call__(
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
# linear proj
- hidden_states = attn.to_out[0](hidden_states, scale=scale)
+ hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -1177,6 +1183,8 @@ def __call__(
) -> torch.FloatTensor:
residual = hidden_states
+ args = () if USE_PEFT_BACKEND else (scale,)
+
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1207,12 +1215,8 @@ def __call__(
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
- key = (
- attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
- )
- value = (
- attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
- )
+ key = attn.to_k(encoder_hidden_states, *args)
+ value = attn.to_v(encoder_hidden_states, *args)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -1232,9 +1236,7 @@ def __call__(
hidden_states = hidden_states.to(query.dtype)
# linear proj
- hidden_states = (
- attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
- )
+ hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -1361,6 +1363,7 @@ def __call__(
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
+
return hidden_states
@@ -1433,8 +1436,11 @@ def __call__(
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
if self.train_kv:
- key = self.to_k_custom_diffusion(encoder_hidden_states)
- value = self.to_v_custom_diffusion(encoder_hidden_states)
+ key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
+ value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
+ key = key.to(attn.to_q.weight.dtype)
+ value = value.to(attn.to_q.weight.dtype)
+
else:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
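The repeated `args = () if USE_PEFT_BACKEND else (scale,)` change above collapses the per-call `scale=scale` branching into a single positional-argument tuple. A minimal sketch of the pattern with a stand-in layer rather than the library's LoRA-compatible projection (the class and flag below are illustrative):

```python
import torch
from torch import nn

USE_PEFT_BACKEND = False  # assumption: stands in for the module-level flag used above


class ScaledLinear(nn.Linear):
    """Stand-in for a LoRA-compatible projection that accepts an optional scale."""

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        return super().forward(hidden_states) * scale  # illustrative scaling only


to_q = ScaledLinear(8, 8)
hidden_states = torch.randn(2, 4, 8)
scale = 0.5

# With the PEFT backend the projections are plain nn.Linear modules that take no
# scale argument, so the processors pass an empty tuple; otherwise the LoRA scale
# is forwarded positionally.
args = () if USE_PEFT_BACKEND else (scale,)
query = to_q(hidden_states, *args)
print(query.shape)  # torch.Size([2, 4, 8])
```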
diff --git a/src/diffusers/models/autoencoder_asym_kl.py b/src/diffusers/models/autoencoder_asym_kl.py
index d8099120918b..9f0fa62d34cd 100644
--- a/src/diffusers/models/autoencoder_asym_kl.py
+++ b/src/diffusers/models/autoencoder_asym_kl.py
@@ -138,6 +138,7 @@ def _decode(
def decode(
self,
z: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
image: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
return_dict: bool = True,
diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py
index 80d2cccd536d..ac616530a66a 100644
--- a/src/diffusers/models/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoder_kl.py
@@ -294,7 +294,9 @@ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decod
return DecoderOutput(sample=dec)
@apply_forward_hook
- def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ def decode(
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
diff --git a/src/diffusers/models/autoencoder_tiny.py b/src/diffusers/models/autoencoder_tiny.py
index 407b1906bba4..15bd53ff99d6 100644
--- a/src/diffusers/models/autoencoder_tiny.py
+++ b/src/diffusers/models/autoencoder_tiny.py
@@ -14,7 +14,7 @@
from dataclasses import dataclass
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
import torch
@@ -307,7 +307,9 @@ def encode(
return AutoencoderTinyOutput(latents=output)
@apply_forward_hook
- def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
+ def decode(
+ self, x: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
+ ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
if self.use_slicing and x.shape[0] > 1:
output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
output = torch.cat(output)
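The `decode()` signature changes above (here and in `AutoencoderKL`/`AsymmetricAutoencoderKL`) let pipelines call every VAE uniformly with `generator=generator`: deterministic decoders accept and ignore the argument, while the consistency decoder consumes it. A small sketch of that duck-typed contract (class names are made up):

```python
import torch


class DeterministicDecoder:
    def decode(self, z, return_dict=True, generator=None):
        # generator accepted for interface compatibility, but unused
        return (z,) if not return_dict else {"sample": z}


class StochasticDecoder:
    def decode(self, z, generator=None, return_dict=True):
        noise = torch.randn(z.shape, generator=generator)  # generator actually consumed
        return (z + noise,) if not return_dict else {"sample": z + noise}


latents = torch.zeros(1, 4, 8, 8)
for vae in (DeterministicDecoder(), StochasticDecoder()):
    sample = vae.decode(latents, return_dict=False, generator=torch.Generator().manual_seed(0))[0]
    print(type(vae).__name__, sample.shape)
```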
diff --git a/src/diffusers/models/consistency_decoder_vae.py b/src/diffusers/models/consistency_decoder_vae.py
new file mode 100644
index 000000000000..63d8763d14b5
--- /dev/null
+++ b/src/diffusers/models/consistency_decoder_vae.py
@@ -0,0 +1,430 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..schedulers import ConsistencyDecoderScheduler
+from ..utils import BaseOutput
+from ..utils.accelerate_utils import apply_forward_hook
+from ..utils.torch_utils import randn_tensor
+from .attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from .modeling_utils import ModelMixin
+from .unet_2d import UNet2DModel
+from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
+
+
+@dataclass
+class ConsistencyDecoderVAEOutput(BaseOutput):
+ """
+ Output of encoding method.
+
+ Args:
+ latent_dist (`DiagonalGaussianDistribution`):
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+ """
+
+ latent_dist: "DiagonalGaussianDistribution"
+
+
+class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
+ r"""
+ The consistency decoder used with DALL-E 3.
+
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline
+
+ >>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
+ ... ).to("cuda")
+
+ >>> pipe("horse", generator=torch.manual_seed(0)).images
+ ```
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ scaling_factor=0.18215,
+ latent_channels=4,
+ encoder_act_fn="silu",
+ encoder_block_out_channels=(128, 256, 512, 512),
+ encoder_double_z=True,
+ encoder_down_block_types=(
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ ),
+ encoder_in_channels=3,
+ encoder_layers_per_block=2,
+ encoder_norm_num_groups=32,
+ encoder_out_channels=4,
+ decoder_add_attention=False,
+ decoder_block_out_channels=(320, 640, 1024, 1024),
+ decoder_down_block_types=(
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ ),
+ decoder_downsample_padding=1,
+ decoder_in_channels=7,
+ decoder_layers_per_block=3,
+ decoder_norm_eps=1e-05,
+ decoder_norm_num_groups=32,
+ decoder_num_train_timesteps=1024,
+ decoder_out_channels=6,
+ decoder_resnet_time_scale_shift="scale_shift",
+ decoder_time_embedding_type="learned",
+ decoder_up_block_types=(
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ ),
+ ):
+ super().__init__()
+ self.encoder = Encoder(
+ act_fn=encoder_act_fn,
+ block_out_channels=encoder_block_out_channels,
+ double_z=encoder_double_z,
+ down_block_types=encoder_down_block_types,
+ in_channels=encoder_in_channels,
+ layers_per_block=encoder_layers_per_block,
+ norm_num_groups=encoder_norm_num_groups,
+ out_channels=encoder_out_channels,
+ )
+
+ self.decoder_unet = UNet2DModel(
+ add_attention=decoder_add_attention,
+ block_out_channels=decoder_block_out_channels,
+ down_block_types=decoder_down_block_types,
+ downsample_padding=decoder_downsample_padding,
+ in_channels=decoder_in_channels,
+ layers_per_block=decoder_layers_per_block,
+ norm_eps=decoder_norm_eps,
+ norm_num_groups=decoder_norm_num_groups,
+ num_train_timesteps=decoder_num_train_timesteps,
+ out_channels=decoder_out_channels,
+ resnet_time_scale_shift=decoder_resnet_time_scale_shift,
+ time_embedding_type=decoder_time_embedding_type,
+ up_block_types=decoder_up_block_types,
+ )
+ self.decoder_scheduler = ConsistencyDecoderScheduler()
+ self.register_to_config(block_out_channels=encoder_block_out_channels)
+ self.register_buffer(
+ "means",
+ torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
+ persistent=False,
+ )
+ self.register_buffer(
+ "stds", torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None], persistent=False
+ )
+
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+
+ self.use_slicing = False
+ self.use_tiling = False
+
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_tiling
+ def enable_tiling(self, use_tiling: bool = True):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
+ processing larger images.
+ """
+ self.use_tiling = use_tiling
+
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_tiling
+ def disable_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.enable_tiling(False)
+
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.enable_slicing
+ def enable_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.disable_slicing
+ def disable_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ @property
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+ ):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor, _remove_lora=_remove_lora)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor, _remove_lora=True)
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.FloatTensor, return_dict: bool = True
+ ) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.FloatTensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ The latent representations of the encoded images. If `return_dict` is True, a
+ [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a plain `tuple`
+ is returned.
+ """
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+ return self.tiled_encode(x, return_dict=return_dict)
+
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self.encoder(x)
+
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return ConsistencyDecoderVAEOutput(latent_dist=posterior)
+
+ @apply_forward_hook
+ def decode(
+ self,
+ z: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ num_inference_steps=2,
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ z = (z * self.config.scaling_factor - self.means) / self.stds
+
+ scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
+ z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
+
+ batch_size, _, height, width = z.shape
+
+ self.decoder_scheduler.set_timesteps(num_inference_steps, device=self.device)
+
+ x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor(
+ (batch_size, 3, height, width), generator=generator, dtype=z.dtype, device=z.device
+ )
+
+ for t in self.decoder_scheduler.timesteps:
+ model_input = torch.concat([self.decoder_scheduler.scale_model_input(x_t, t), z], dim=1)
+ model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :]
+ prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator).prev_sample
+ x_t = prev_sample
+
+ x_0 = x_t
+
+ if not return_dict:
+ return (x_0,)
+
+ return DecoderOutput(sample=x_0)
+
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_v
+ def blend_v(self, a, b, blend_extent):
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+ return b
+
+ # Copied from diffusers.models.autoencoder_kl.AutoencoderKL.blend_h
+ def blend_h(self, a, b, blend_extent):
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+ return b
+
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput:
+ r"""Encode a batch of images using a tiled encoder.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+ different from non-tiled encoding because each tile is encoded separately. To avoid tiling artifacts, the
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+ output, but they should be much less noticeable.
+
+ Args:
+ x (`torch.FloatTensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a
+ plain tuple.
+
+ Returns:
+ [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`:
+ If return_dict is True, a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned,
+ otherwise a plain `tuple` is returned.
+ """
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+ # Split the image into 512x512 tiles and encode them separately.
+ rows = []
+ for i in range(0, x.shape[2], overlap_size):
+ row = []
+ for j in range(0, x.shape[3], overlap_size):
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+ tile = self.encoder(tile)
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ moments = torch.cat(result_rows, dim=2)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return ConsistencyDecoderVAEOutput(latent_dist=posterior)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, generator=generator).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
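As a quick sanity check on `ConsistencyDecoderVAE.decode()` above: the latents are rescaled by `scaling_factor`, normalized with the registered per-channel `means`/`stds`, and nearest-upsampled by `2 ** (len(block_out_channels) - 1)` before being concatenated with the noisy sample for the consistency UNet. A standalone sketch of that pre-processing using the defaults from the config above:

```python
import torch
import torch.nn.functional as F

scaling_factor = 0.18215
means = torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None]
stds = torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None]
block_out_channels = (128, 256, 512, 512)  # encoder_block_out_channels default

z = torch.randn(1, 4, 64, 64)  # Stable-Diffusion-style latents
z = (z * scaling_factor - means) / stds
scale_factor = 2 ** (len(block_out_channels) - 1)  # -> 8
z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
print(z.shape)  # torch.Size([1, 4, 512, 512]); concatenated with a (1, 3, 512, 512) noise sample
```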
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 8fe66aacf5db..868e2e5fae2c 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -778,16 +778,22 @@ class Conv1dBlock(nn.Module):
out_channels (`int`): Number of output channels.
kernel_size (`int` or `tuple`): Size of the convolving kernel.
n_groups (`int`, default `8`): Number of groups to separate the channels into.
+ activation (`str`, defaults to `mish`): Name of the activation function.
"""
def __init__(
- self, inp_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], n_groups: int = 8
+ self,
+ inp_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ n_groups: int = 8,
+ activation: str = "mish",
):
super().__init__()
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.group_norm = nn.GroupNorm(n_groups, out_channels)
- self.mish = nn.Mish()
+ self.mish = get_activation(activation)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
intermediate_repr = self.conv1d(inputs)
@@ -808,16 +814,22 @@ class ResidualTemporalBlock1D(nn.Module):
out_channels (`int`): Number of output channels.
embed_dim (`int`): Embedding dimension.
kernel_size (`int` or `tuple`): Size of the convolving kernel.
+ activation (`str`, defaults to `mish`): Name of the activation function used for the time embedding.
"""
def __init__(
- self, inp_channels: int, out_channels: int, embed_dim: int, kernel_size: Union[int, Tuple[int, int]] = 5
+ self,
+ inp_channels: int,
+ out_channels: int,
+ embed_dim: int,
+ kernel_size: Union[int, Tuple[int, int]] = 5,
+ activation: str = "mish",
):
super().__init__()
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
- self.time_emb_act = nn.Mish()
+ self.time_emb_act = get_activation(activation)
self.time_emb = nn.Linear(embed_dim, out_channels)
self.residual_conv = (
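The resnet changes above swap the hard-coded `nn.Mish()` for a name-based lookup. A hedged sketch of the pattern with a local stand-in for `get_activation` (the real helper lives elsewhere in diffusers; this block is illustrative only):

```python
import torch
from torch import nn


def get_activation(name: str) -> nn.Module:
    # Stand-in for diffusers' helper; only a few names are wired up here.
    return {"mish": nn.Mish(), "silu": nn.SiLU(), "gelu": nn.GELU()}[name]


class Conv1dBlockSketch(nn.Module):
    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8, activation="mish"):
        super().__init__()
        self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
        self.group_norm = nn.GroupNorm(n_groups, out_channels)
        self.act = get_activation(activation)

    def forward(self, inputs):
        return self.act(self.group_norm(self.conv1d(inputs)))


block = Conv1dBlockSketch(8, 16, 5, activation="silu")
print(block(torch.randn(1, 8, 32)).shape)  # torch.Size([1, 16, 32])
```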
diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py
index 7c0cd12d1c67..24abf54d6da7 100644
--- a/src/diffusers/models/transformer_2d.py
+++ b/src/diffusers/models/transformer_2d.py
@@ -339,6 +339,7 @@ def forward(
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states = self.pos_embed(hidden_states)
if self.adaln_single is not None:
@@ -425,7 +426,8 @@ def forward(
hidden_states = hidden_states.squeeze(1)
# unpatchify
- height = width = int(hidden_states.shape[1] ** 0.5)
+ if self.adaln_single is None:
+ height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
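The transformer_2d change above remembers `height`/`width` from the patched input instead of assuming a square token grid, which matters once the `adaln_single` path (PixArt-alpha) feeds non-square latents. A small sketch of the unpatchify step showing why `int(num_tokens ** 0.5)` would be wrong for a 64x96 latent (shapes are illustrative):

```python
import torch

patch_size, out_channels = 2, 4
hidden_states_in = torch.randn(1, 8, 64, 96)        # non-square latent
height = hidden_states_in.shape[-2] // patch_size    # 32
width = hidden_states_in.shape[-1] // patch_size     # 48

tokens = torch.randn(1, height * width, patch_size * patch_size * out_channels)
x = tokens.reshape(-1, height, width, patch_size, patch_size, out_channels)
x = torch.einsum("nhwpqc->nchpwq", x)
out = x.reshape(-1, out_channels, height * patch_size, width * patch_size)
print(out.shape)  # torch.Size([1, 4, 64, 96]); sqrt(32 * 48) would not recover 32 and 48
```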
diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py
index 38e26422e2a7..0531d8aae783 100644
--- a/src/diffusers/models/unet_2d.py
+++ b/src/diffusers/models/unet_2d.py
@@ -117,6 +117,7 @@ def __init__(
add_attention: bool = True,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
+ num_train_timesteps: Optional[int] = None,
):
super().__init__()
@@ -144,6 +145,9 @@ def __init__(
elif time_embedding_type == "positional":
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
+ elif time_embedding_type == "learned":
+ self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
+ timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
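The new `"learned"` time-embedding branch in `UNet2DModel` (used by the consistency decoder's UNet via `decoder_time_embedding_type="learned"`) looks up a trainable vector per discrete timestep instead of computing sinusoidal features. A minimal sketch:

```python
import torch
from torch import nn

num_train_timesteps = 1024                   # decoder_num_train_timesteps above
block_out_channels = (320, 640, 1024, 1024)  # decoder_block_out_channels above

time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
timesteps = torch.tensor([3, 17])            # integer timestep indices, one per batch element
t_emb = time_proj(timesteps)                 # trainable (2, 320) embeddings fed to TimestepEmbedding
print(t_emb.shape)
```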
diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py
index 08ad122c3891..0c93b9142bea 100644
--- a/src/diffusers/models/vq_model.py
+++ b/src/diffusers/models/vq_model.py
@@ -162,8 +162,8 @@ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[
If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
is returned.
"""
- x = sample
- h = self.encode(x).latents
+
+ h = self.encode(sample).latents
dec = self.decode(h).sample
if not return_dict:
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
index a73dc22a146c..2dbc2604ffce 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
@@ -588,6 +588,34 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ guidance scale values at which to generate the embedding vectors
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
@property
def guidance_scale(self):
return self._guidance_scale
@@ -605,7 +633,7 @@ def clip_skip(self):
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
def cross_attention_kwargs(self):
@@ -804,6 +832,14 @@ def __call__(
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ # 6.5 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
@@ -818,6 +854,7 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
return_dict=False,
)[0]
@@ -852,7 +889,9 @@ def __call__(
callback(step_idx, t, latents)
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
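`get_guidance_scale_embedding()` above encodes `guidance_scale - 1` with the same sin/cos scheme used for timesteps, so guidance-distilled UNets (those with `time_cond_proj_dim` set, e.g. LCM checkpoints) can receive it through `timestep_cond`; classifier-free guidance is disabled for those models by the updated `do_classifier_free_guidance` property. A standalone reproduction of the added helper for a quick shape check:

```python
import torch


def get_guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    # Same computation as the method added above, reproduced standalone.
    assert len(w.shape) == 1
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb


guidance_scale, batch_size, num_images_per_prompt = 7.5, 2, 1
w = torch.tensor(guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
timestep_cond = get_guidance_scale_embedding(w, embedding_dim=256)
print(timestep_cond.shape)  # torch.Size([2, 256])
```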
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
index ea4a3128dee3..baaadefaad3e 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -646,6 +646,34 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ guidance scale values at which to generate the embedding vectors
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
@property
def guidance_scale(self):
return self._guidance_scale
@@ -659,7 +687,7 @@ def clip_skip(self):
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
def cross_attention_kwargs(self):
@@ -849,6 +877,14 @@ def __call__(
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ # 7.5 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
@@ -863,6 +899,7 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
return_dict=False,
)[0]
@@ -893,7 +930,9 @@ def __call__(
callback(step_idx, t, latents)
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index 49947f9dbf32..b63acb9a5f30 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -498,7 +498,7 @@ def prepare_latents(
@torch.no_grad()
def __call__(
self,
- prompt: Union[str, List[str]],
+ prompt: Union[str, List[str]] = None,
num_frames: Optional[int] = 16,
height: Optional[int] = None,
width: Optional[int] = None,
diff --git a/src/diffusers/pipelines/consistency_models/__init__.py b/src/diffusers/pipelines/consistency_models/__init__.py
index 053a3666263f..162d91c010ac 100644
--- a/src/diffusers/pipelines/consistency_models/__init__.py
+++ b/src/diffusers/pipelines/consistency_models/__init__.py
@@ -6,7 +6,9 @@
)
-_import_structure = {"pipeline_consistency_models": ["ConsistencyModelPipeline"]}
+_import_structure = {
+ "pipeline_consistency_models": ["ConsistencyModelPipeline"],
+}
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_consistency_models import ConsistencyModelPipeline
diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
index de1b1fd93c7f..6465250a762a 100644
--- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
+++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
@@ -1,3 +1,17 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from typing import Callable, List, Optional, Union
import torch
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index 6944d9331253..04ca51b19f05 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -1058,7 +1058,9 @@ def __call__(
torch.cuda.empty_cache()
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
index b692d936c5b7..8683d09e118d 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -1138,7 +1138,9 @@ def __call__(
torch.cuda.empty_cache()
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 3e0b07bf1694..399cfdcf9c2c 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -1405,7 +1405,9 @@ def __call__(
torch.cuda.empty_cache()
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
index e595b3423995..8380dd210d9c 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
@@ -1109,8 +1109,6 @@ def __call__(
nsfw_detected = None
watermark_detected = None
- if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
- self.unet_offload_hook.offload()
else:
# 10. Post-processing
image = (image / 2 + 0.5).clamp(0, 1)
@@ -1119,9 +1117,7 @@ def __call__(
# 11. Run safety checker
image, nsfw_detected, watermark_detected = self.run_safety_checker(image, device, prompt_embeds.dtype)
- # Offload last model to CPU
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
- self.final_offload_hook.offload()
+ self.maybe_free_model_hooks()
if not return_dict:
return (image, nsfw_detected, watermark_detected)
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
index 5c78b0dce87e..5e7a69e756ce 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
@@ -388,6 +388,8 @@ def __call__(
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+ self.maybe_free_model_hooks()
+
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
index 25508e1e080f..eff8af4c723e 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
@@ -321,6 +321,9 @@ def __call__(
callback_steps=callback_steps,
return_dict=return_dict,
)
+
+ self.maybe_free_model_hooks()
+
return outputs
@@ -558,6 +561,9 @@ def __call__(
callback_steps=callback_steps,
return_dict=return_dict,
)
+
+ self.maybe_free_model_hooks()
+
return outputs
@@ -593,7 +599,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
"""
_load_connected_pipes = True
- model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq"
+ model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
def __init__(
self,
@@ -802,4 +808,7 @@ def __call__(
callback_steps=callback_steps,
return_dict=return_dict,
)
+
+ self.maybe_free_model_hooks()
+
return outputs
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
index a22823aadef4..c5e7af270906 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
@@ -481,6 +481,8 @@ def __call__(
# 7. post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+ self.maybe_free_model_hooks()
+
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
index 144e3ce585af..e9b5eb5cdd70 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
@@ -616,6 +616,8 @@ def __call__(
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
+ self.maybe_free_model_hooks()
+
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
index c9a6019a8eac..a9c12b258974 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
@@ -527,7 +527,7 @@ def __call__(
if negative_prompt is None:
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
- self.maybe_free_model_hooks
+ self.maybe_free_model_hooks()
else:
image_embeddings, zero_embeds = image_embeddings.chunk(2)
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
index 097673d904f5..2c7caa6214e5 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
@@ -326,6 +326,8 @@ def __call__(
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
+ self.maybe_free_model_hooks()
+
return outputs
@@ -572,6 +574,8 @@ def __call__(
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
+
+ self.maybe_free_model_hooks()
return outputs
@@ -842,4 +846,6 @@ def __call__(
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
**kwargs,
)
+ self.maybe_free_model_hooks()
+
return outputs
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
index 345b3ae65721..8d0e788b9dd9 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
@@ -531,14 +531,10 @@ def __call__(
# if negative prompt has been defined, we retrieve split the image embedding into two
if negative_prompt is None:
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
-
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
- self.final_offload_hook.offload()
else:
image_embeddings, zero_embeds = image_embeddings.chunk(2)
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
- self.prior_hook.offload()
+ self.maybe_free_model_hooks()
if output_type not in ["pt", "np"]:
raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
index b4a6a64137ec..bef70821c605 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
@@ -545,12 +545,10 @@ def __call__(
# if negative prompt has been defined, we retrieve split the image embedding into two
if negative_prompt is None:
zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device)
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
- self.final_offload_hook.offload()
else:
image_embeddings, zero_embeds = image_embeddings.chunk(2)
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
- self.prior_hook.offload()
+
+ self.maybe_free_model_hooks()
if output_type not in ["pt", "np"]:
raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/latent_consistency_models/__init__.py b/src/diffusers/pipelines/latent_consistency_models/__init__.py
index 14002058cdfd..8f79d3c4773f 100644
--- a/src/diffusers/pipelines/latent_consistency_models/__init__.py
+++ b/src/diffusers/pipelines/latent_consistency_models/__init__.py
@@ -1,19 +1,40 @@
from typing import TYPE_CHECKING
from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
_LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
)
-_import_structure = {
- "pipeline_latent_consistency_img2img": ["LatentConsistencyModelImg2ImgPipeline"],
- "pipeline_latent_consistency_text2img": ["LatentConsistencyModelPipeline"],
-}
+_dummy_objects = {}
+_import_structure = {}
-if TYPE_CHECKING:
- from .pipeline_latent_consistency_img2img import LatentConsistencyModelImg2ImgPipeline
- from .pipeline_latent_consistency_text2img import LatentConsistencyModelPipeline
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_latent_consistency_img2img"] = ["LatentConsistencyModelImg2ImgPipeline"]
+ _import_structure["pipeline_latent_consistency_text2img"] = ["LatentConsistencyModelPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_latent_consistency_img2img import LatentConsistencyModelImg2ImgPipeline
+ from .pipeline_latent_consistency_text2img import LatentConsistencyModelPipeline
else:
import sys
@@ -24,3 +45,6 @@
_import_structure,
module_spec=__spec__,
)
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
index ccc84e22c252..679415db7f3a 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -60,7 +60,7 @@ def retrieve_latents(encoder_output, generator):
>>> import torch
>>> import PIL
- >>> pipe = DiffusionPipeline.from_pretrained("SimianLuo/LCM_Dreamshaper_v7")
+ >>> pipe = AutoPipelineForImage2Image.from_pretrained("SimianLuo/LCM_Dreamshaper_v7")
>>> # To save GPU memory, torch.float16 can be used, but it may compromise image quality.
>>> pipe.to(torch_device="cuda", torch_dtype=torch.float32)
@@ -738,7 +738,7 @@ def __call__(
if original_inference_steps is not None
else self.scheduler.config.original_inference_steps
)
- latent_timestep = torch.tensor(int(strength * original_inference_steps))
+ latent_timestep = timesteps[:1]
latents = self.prepare_latents(
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
)
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 9e019794692e..6437732d0315 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -158,9 +158,9 @@ def is_safetensors_compatible(filenames, variant=None, passed_components=None) -
continue
if extension == ".bin":
- pt_filenames.append(filename)
+ pt_filenames.append(os.path.normpath(filename))
elif extension == ".safetensors":
- sf_filenames.add(filename)
+ sf_filenames.add(os.path.normpath(filename))
for filename in pt_filenames:
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam'
@@ -172,9 +172,8 @@ def is_safetensors_compatible(filenames, variant=None, passed_components=None) -
else:
filename = filename
- expected_sf_filename = os.path.join(path, filename)
+ expected_sf_filename = os.path.normpath(os.path.join(path, filename))
expected_sf_filename = f"{expected_sf_filename}.safetensors"
-
if expected_sf_filename not in sf_filenames:
logger.warning(f"{expected_sf_filename} not found")
return False
@@ -1774,7 +1773,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
)
):
raise EnvironmentError(
- f"Could not found the necessary `safetensors` weights in {model_filenames} (variant={variant})"
+ f"Could not find the necessary `safetensors` weights in {model_filenames} (variant={variant})"
)
if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
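The `os.path.normpath` calls added to `is_safetensors_compatible` make the `.bin` vs `.safetensors` comparison robust to differently normalized paths (redundant `./` segments or platform-specific separators) in the collected filename lists. A tiny illustration with made-up filenames:

```python
import os

sf_filenames = {os.path.normpath("unet/diffusion_pytorch_model.safetensors")}
bin_filename = os.path.normpath("unet/./diffusion_pytorch_model.bin")

path, fname = os.path.split(bin_filename)
expected = os.path.normpath(os.path.join(path, fname.replace(".bin", ""))) + ".safetensors"
print(expected in sf_filenames)  # True once both sides are normalized
```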
diff --git a/src/diffusers/pipelines/pixart_alpha/__init__.py b/src/diffusers/pipelines/pixart_alpha/__init__.py
index e0d238907a06..0bfa28fcde50 100644
--- a/src/diffusers/pipelines/pixart_alpha/__init__.py
+++ b/src/diffusers/pipelines/pixart_alpha/__init__.py
@@ -1 +1,48 @@
-from .pipeline_pixart_alpha import PixArtAlphaPipeline
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_pixart_alpha"] = ["PixArtAlphaPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_pixart_alpha import PixArtAlphaPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index 1f39cc168c6f..147e2b76e6c6 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -156,6 +156,8 @@ def encode_prompt(
mask_feature: (bool, defaults to `True`):
If `True`, the function will mask the text embeddings.
"""
+ embeds_initially_provided = prompt_embeds is not None and negative_prompt_embeds is not None
+
if device is None:
device = self._execution_device
@@ -253,7 +255,7 @@ def encode_prompt(
negative_prompt_embeds = None
# Perform additional masking.
- if mask_feature:
+ if mask_feature and not embeds_initially_provided:
prompt_embeds = prompt_embeds.unsqueeze(1)
masked_prompt_embeds, keep_indices = self.mask_text_embeddings(prompt_embeds, prompt_embeds_attention_mask)
masked_prompt_embeds = masked_prompt_embeds.squeeze(1)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
index 8a5eb066f4fa..9bdb6d824f99 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
@@ -918,6 +918,7 @@ def __call__(
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+ self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index fe85c6391fa9..e9f49c1a7641 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -581,6 +581,35 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ guidance scale values at which to generate the embedding vectors
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
@property
def guidance_scale(self):
return self._guidance_scale
@@ -598,7 +627,7 @@ def clip_skip(self):
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
def cross_attention_kwargs(self):
@@ -795,6 +824,14 @@ def __call__(
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ # 6.5 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
@@ -809,6 +846,7 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
return_dict=False,
)[0]
@@ -843,7 +881,9 @@ def __call__(
callback(step_idx, t, latents)
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
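
The additions to `pipeline_stable_diffusion.py` (repeated for the img2img, inpaint, and SDXL pipelines below) port the LCM-style guidance embedding: when the UNet was trained with `time_cond_proj_dim`, the guidance scale is not applied via classifier-free guidance (hence the extra `and self.unet.config.time_cond_proj_dim is None` in `do_classifier_free_guidance`) but is instead turned into a sinusoidal embedding and fed to the UNet as `timestep_cond`. A standalone sketch of the same math, with illustrative names that are not part of the diffusers API:

```python
import torch

def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 256, dtype=torch.float32):
    """Sinusoidal embedding of w = guidance_scale - 1 (same computation as the copied helper above)."""
    assert w.ndim == 1
    w = w * 1000.0
    half_dim = embedding_dim // 2
    freqs = torch.exp(-torch.log(torch.tensor(10000.0)) * torch.arange(half_dim, dtype=dtype) / (half_dim - 1))
    emb = w.to(dtype)[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad to the requested width
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb  # shape (len(w), embedding_dim)

# For an LCM-distilled UNet (unet.config.time_cond_proj_dim set), the pipeline builds
# timestep_cond once per call and skips the usual negative/positive prompt batch doubling.
batch_size, guidance_scale = 2, 7.5
w = torch.full((batch_size,), guidance_scale - 1.0)
timestep_cond = guidance_scale_embedding(w, embedding_dim=256)
print(timestep_cond.shape)  # torch.Size([2, 256])
```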
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
index f897b51941a6..2e040306abfd 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py
@@ -1027,6 +1027,7 @@ def __call__(
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+ self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
index 7f6845128f6c..36efb01f23ef 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -846,6 +846,7 @@ def __call__(
image = latents
image = self.image_processor.postprocess(image, output_type=output_type)
+ self.maybe_free_model_hooks()
if not return_dict:
return (image,)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
index c6797a0693cc..e8f48a163066 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -439,6 +439,8 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+ self.maybe_free_model_hooks()
+
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 583e6046b2e1..40daecfa913f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -640,6 +640,35 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ timesteps (`torch.Tensor`):
+ generate embedding vectors at these timesteps
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
@property
def guidance_scale(self):
return self._guidance_scale
@@ -653,7 +682,7 @@ def clip_skip(self):
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
def cross_attention_kwargs(self):
@@ -841,6 +870,14 @@ def __call__(
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ # 7.5 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
@@ -855,6 +892,7 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
return_dict=False,
)[0]
@@ -885,7 +923,9 @@ def __call__(
callback(step_idx, t, latents)
if not output_type == "latent":
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index e306463a4447..1b9eae657420 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -766,6 +766,35 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ timesteps (`torch.Tensor`):
+ generate embedding vectors at these timesteps
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
@property
def guidance_scale(self):
return self._guidance_scale
@@ -779,7 +808,7 @@ def clip_skip(self):
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
def cross_attention_kwargs(self):
@@ -1088,6 +1117,14 @@ def __call__(
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+ # 9.5 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
# 10. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
@@ -1107,6 +1144,7 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
return_dict=False,
)[0]
@@ -1160,7 +1198,9 @@ def __call__(
init_image = self._encode_vae_image(init_image, generator=generator)
mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0]
+ image = self.vae.decode(
+ latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs
+ )[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
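
The other recurring change in these pipelines is that `generator` is now threaded into `self.vae.decode(...)`. For the default deterministic `AutoencoderKL` this should make no visible difference, but it matters when the VAE is swapped for the `ConsistencyDecoderVAE` introduced elsewhere in this diff, whose decode step draws noise. A hedged sketch of that combination; the model ids are assumptions, so check the Hub for the exact checkpoint names:

```python
import torch
from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline

# Pair a Stable Diffusion pipeline with the stochastic consistency decoder VAE.
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
).to("cuda")

# The same generator now seeds both the denoising loop and the VAE decode, so the whole
# run is reproducible even though decoding itself is stochastic.
generator = torch.Generator("cuda").manual_seed(0)
image = pipe("a photo of a corgi", generator=generator).images[0]
```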
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
index 1e8c98c44750..4cde54ac587a 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -511,6 +511,8 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type)
+ self.maybe_free_model_hooks()
+
if not return_dict:
return (image,)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
index f53e34e9259a..ce3e694e7e32 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py
@@ -802,6 +802,8 @@ def __call__(
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+ self.maybe_free_model_hooks()
+
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
index 80f1d49ae297..56eb38c653ba 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py
@@ -741,6 +741,8 @@ def get_map_size(module, input, output):
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+ self.maybe_free_model_hooks()
+
if not return_dict:
return (image, has_nsfw_concept)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index c81dd85f0e46..eb4542888c1f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -206,17 +206,15 @@ def _encode_prior_prompt(
prior_text_encoder_output = self.prior_text_encoder(text_input_ids.to(device))
prompt_embeds = prior_text_encoder_output.text_embeds
- prior_text_encoder_hidden_states = prior_text_encoder_output.last_hidden_state
+ text_enc_hid_states = prior_text_encoder_output.last_hidden_state
else:
batch_size = text_model_output[0].shape[0]
- prompt_embeds, prior_text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+ prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
text_mask = text_attention_mask
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
- prior_text_encoder_hidden_states = prior_text_encoder_hidden_states.repeat_interleave(
- num_images_per_prompt, dim=0
- )
+ text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
@@ -235,9 +233,7 @@ def _encode_prior_prompt(
)
negative_prompt_embeds = negative_prompt_embeds_prior_text_encoder_output.text_embeds
- uncond_prior_text_encoder_hidden_states = (
- negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
- )
+ uncond_text_enc_hid_states = negative_prompt_embeds_prior_text_encoder_output.last_hidden_state
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -245,11 +241,9 @@ def _encode_prior_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
- seq_len = uncond_prior_text_encoder_hidden_states.shape[1]
- uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.repeat(
- 1, num_images_per_prompt, 1
- )
- uncond_prior_text_encoder_hidden_states = uncond_prior_text_encoder_hidden_states.view(
+ seq_len = uncond_text_enc_hid_states.shape[1]
+ uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
batch_size * num_images_per_prompt, seq_len, -1
)
uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -260,13 +254,11 @@ def _encode_prior_prompt(
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- prior_text_encoder_hidden_states = torch.cat(
- [uncond_prior_text_encoder_hidden_states, prior_text_encoder_hidden_states]
- )
+ text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])
text_mask = torch.cat([uncond_text_mask, text_mask])
- return prompt_embeds, prior_text_encoder_hidden_states, text_mask
+ return prompt_embeds, text_enc_hid_states, text_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 6143d2210c3c..151cbed4e08f 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -636,6 +636,35 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ timesteps (`torch.Tensor`):
+ generate embedding vectors at these timesteps
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
@property
def guidance_scale(self):
return self._guidance_scale
@@ -653,7 +682,7 @@ def clip_skip(self):
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
def cross_attention_kwargs(self):
@@ -989,6 +1018,14 @@ def __call__(
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
+ # 9. Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -1003,6 +1040,7 @@ def __call__(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 02a220fa851b..f444eddec0ab 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -763,6 +763,35 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ timesteps (`torch.Tensor`):
+ generate embedding vectors at these timesteps
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
@property
def guidance_scale(self):
return self._guidance_scale
@@ -780,7 +809,7 @@ def clip_skip(self):
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
def cross_attention_kwargs(self):
@@ -1156,6 +1185,15 @@ def denoising_value_valid(dnv):
)
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
+
+ # 9.2 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -1170,6 +1208,7 @@ def denoising_value_valid(dnv):
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 7890774c7539..667e7aec00ed 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -982,6 +982,35 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ timesteps (`torch.Tensor`):
+ generate embedding vectors at these timesteps
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
@property
def guidance_scale(self):
return self._guidance_scale
@@ -999,7 +1028,7 @@ def clip_skip(self):
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
- return self._guidance_scale > 1
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
def cross_attention_kwargs(self):
@@ -1464,6 +1493,14 @@ def denoising_value_valid(dnv):
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
+ # 11.1 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -1482,6 +1519,7 @@ def denoising_value_valid(dnv):
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py
index aafbfe439f32..e7a4f1723cff 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py
@@ -156,15 +156,15 @@ def _encode_prompt(
text_encoder_output = self.text_encoder(text_input_ids.to(device))
prompt_embeds = text_encoder_output.text_embeds
- text_encoder_hidden_states = text_encoder_output.last_hidden_state
+ text_enc_hid_states = text_encoder_output.last_hidden_state
else:
batch_size = text_model_output[0].shape[0]
- prompt_embeds, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+ prompt_embeds, text_enc_hid_states = text_model_output[0], text_model_output[1]
text_mask = text_attention_mask
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
- text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_enc_hid_states = text_enc_hid_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
@@ -181,7 +181,7 @@ def _encode_prompt(
negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds
- uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+ uncond_text_enc_hid_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -189,9 +189,9 @@ def _encode_prompt(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len)
- seq_len = uncond_text_encoder_hidden_states.shape[1]
- uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
- uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ seq_len = uncond_text_enc_hid_states.shape[1]
+ uncond_text_enc_hid_states = uncond_text_enc_hid_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_enc_hid_states = uncond_text_enc_hid_states.view(
batch_size * num_images_per_prompt, seq_len, -1
)
uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
@@ -202,11 +202,11 @@ def _encode_prompt(
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
- text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+ text_enc_hid_states = torch.cat([uncond_text_enc_hid_states, text_enc_hid_states])
text_mask = torch.cat([uncond_text_mask, text_mask])
- return prompt_embeds, text_encoder_hidden_states, text_mask
+ return prompt_embeds, text_enc_hid_states, text_mask
@torch.no_grad()
def __call__(
@@ -301,7 +301,7 @@ def __call__(
do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
- prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt(
+ prompt_embeds, text_enc_hid_states, text_mask = self._encode_prompt(
prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
)
current_step = 0
@@ -329,7 +329,7 @@ def __call__(
latent_model_input,
timestep=t,
proj_embedding=prompt_embeds,
- encoder_hidden_states=text_encoder_hidden_states,
+ encoder_hidden_states=text_enc_hid_states,
attention_mask=text_mask,
).predicted_image_embedding
@@ -365,10 +365,10 @@ def __call__(
# decoder
- text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
+ text_enc_hid_states, additive_clip_time_embeddings = self.text_proj(
image_embeddings=image_embeddings,
prompt_embeds=prompt_embeds,
- text_encoder_hidden_states=text_encoder_hidden_states,
+ text_encoder_hidden_states=text_enc_hid_states,
do_classifier_free_guidance=do_classifier_free_guidance,
)
@@ -390,7 +390,7 @@ def __call__(
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
- text_encoder_hidden_states.dtype,
+ text_enc_hid_states.dtype,
device,
generator,
decoder_latents,
@@ -404,7 +404,7 @@ def __call__(
noise_pred = self.decoder(
sample=latent_model_input,
timestep=t,
- encoder_hidden_states=text_encoder_hidden_states,
+ encoder_hidden_states=text_enc_hid_states,
class_labels=additive_clip_time_embeddings,
attention_mask=decoder_text_mask,
).sample
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 32147ffa455b..60ea3d814b3a 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -1494,7 +1494,6 @@ def forward(self, input_tensor, temb):
return output_tensor
-# Copied from diffusers.models.unet_2d_blocks.DownBlock2D with DownBlock2D->DownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
class DownBlockFlat(nn.Module):
def __init__(
self,
@@ -1583,7 +1582,6 @@ def custom_forward(*inputs):
return hidden_states, output_states
-# Copied from diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D with CrossAttnDownBlock2D->CrossAttnDownBlockFlat, ResnetBlock2D->ResnetBlockFlat, Downsample2D->LinearMultiDim
class CrossAttnDownBlockFlat(nn.Module):
def __init__(
self,
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 85fd9d25e5da..5e5102e589d4 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -38,6 +38,7 @@
_dummy_modules.update(get_objects_from_module(dummy_pt_objects))
else:
+ _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
_import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
@@ -128,6 +129,7 @@
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
+ from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler
from .scheduling_ddim_inverse import DDIMInverseScheduler
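
The `schedulers/__init__.py` change registers the new class in both the lazy `_import_structure` map and the eager import branch, so `ConsistencyDecoderScheduler` is importable without paying for every scheduler at package import time. A generic sketch of that lazy-import idea using a PEP 562 module `__getattr__` (diffusers uses its own `_LazyModule` helper, so treat this as an illustration rather than the library's actual mechanism):

```python
# Sketch of a lazy package __init__.py: record which submodule provides each name and
# import it only on first attribute access.
import importlib

_import_structure = {
    "scheduling_consistency_decoder": ["ConsistencyDecoderScheduler"],
    "scheduling_ddim": ["DDIMScheduler"],
}
_name_to_module = {name: module for module, names in _import_structure.items() for name in names}

def __getattr__(name):
    if name in _name_to_module:
        module = importlib.import_module(f".{_name_to_module[name]}", __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```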
diff --git a/src/diffusers/schedulers/scheduling_consistency_decoder.py b/src/diffusers/schedulers/scheduling_consistency_decoder.py
new file mode 100644
index 000000000000..69ca8a1737ec
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_consistency_decoder.py
@@ -0,0 +1,180 @@
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from ..utils.torch_utils import randn_tensor
+from .scheduling_utils import SchedulerMixin
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+@dataclass
+class ConsistencyDecoderSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1024,
+ sigma_data: float = 0.5,
+ ):
+ betas = betas_for_alpha_bar(num_train_timesteps)
+
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+
+ self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+
+ sigmas = torch.sqrt(1.0 / alphas_cumprod - 1)
+
+ sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
+
+ self.c_skip = sqrt_recip_alphas_cumprod * sigma_data**2 / (sigmas**2 + sigma_data**2)
+ self.c_out = sigmas * sigma_data / (sigmas**2 + sigma_data**2) ** 0.5
+ self.c_in = sqrt_recip_alphas_cumprod / (sigmas**2 + sigma_data**2) ** 0.5
+
+ def set_timesteps(
+ self,
+ num_inference_steps: Optional[int] = None,
+ device: Union[str, torch.device] = None,
+ ):
+ if num_inference_steps != 2:
+ raise ValueError("Currently more than 2 inference steps are not supported.")
+
+ self.timesteps = torch.tensor([1008, 512], dtype=torch.long, device=device)
+ self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
+ self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
+ self.c_skip = self.c_skip.to(device)
+ self.c_out = self.c_out.to(device)
+ self.c_in = self.c_in.to(device)
+
+ @property
+ def init_noise_sigma(self):
+ return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]]
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+ `torch.FloatTensor`:
+ A scaled input sample.
+ """
+ return sample * self.c_in[timestep]
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: Union[float, torch.FloatTensor],
+ sample: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[ConsistencyDecoderSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from the learned diffusion model.
+ timestep (`float`):
+ The current timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a
+ [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`.
+
+ Returns:
+ [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`:
+ If return_dict is `True`,
+ [`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
+ a tuple is returned where the first element is the sample tensor.
+ """
+ x_0 = self.c_out[timestep] * model_output + self.c_skip[timestep] * sample
+
+ timestep_idx = torch.where(self.timesteps == timestep)[0]
+
+ if timestep_idx == len(self.timesteps) - 1:
+ prev_sample = x_0
+ else:
+ noise = randn_tensor(x_0.shape, generator=generator, dtype=x_0.dtype, device=x_0.device)
+ prev_sample = (
+ self.sqrt_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * x_0
+ + self.sqrt_one_minus_alphas_cumprod[self.timesteps[timestep_idx + 1]].to(x_0.dtype) * noise
+ )
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return ConsistencyDecoderSchedulerOutput(prev_sample=prev_sample)
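
The new `ConsistencyDecoderScheduler` supports a fixed two-step schedule (`[1008, 512]`): `scale_model_input` multiplies the sample by `c_in[t]`, and `step` forms `x_0 = c_out[t] * model_output + c_skip[t] * sample`, re-noising to the next timestep on every step except the final one. A minimal sketch of driving it, with a stand-in callable in place of the real consistency decoder network:

```python
import torch
from diffusers.schedulers import ConsistencyDecoderScheduler

scheduler = ConsistencyDecoderScheduler()
scheduler.set_timesteps(2, device="cpu")  # only the fixed two-step [1008, 512] schedule is supported

def fake_decoder(x, t):
    # Stand-in for model(x, t); the real decoder is a UNet conditioned on the latents.
    return torch.zeros_like(x)

generator = torch.Generator().manual_seed(0)
sample = torch.randn(1, 3, 64, 64, generator=generator) * scheduler.init_noise_sigma

for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    model_output = fake_decoder(model_input, t)
    sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample

print(sample.shape)  # torch.Size([1, 3, 64, 64])
```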
diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py
index 8e2627b6f477..adcc092a816f 100644
--- a/src/diffusers/schedulers/scheduling_lcm.py
+++ b/src/diffusers/schedulers/scheduling_lcm.py
@@ -182,6 +182,10 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ timestep_scaling (`float`, defaults to 10.0):
+ The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
+ `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
+ error at the default of `10.0` is already pretty small).
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
@@ -208,6 +212,7 @@ def __init__(
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
timestep_spacing: str = "leading",
+ timestep_scaling: float = 10.0,
rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
@@ -380,12 +385,12 @@ def set_timesteps(
self._step_index = None
- def get_scalings_for_boundary_condition_discrete(self, t):
+ def get_scalings_for_boundary_condition_discrete(self, timestep):
self.sigma_data = 0.5 # Default: 0.5
+ scaled_timestep = timestep * self.config.timestep_scaling
- # By dividing 0.1: This is almost a delta function at t=0.
- c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2)
- c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5
+ c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
return c_skip, c_out
def step(
@@ -466,9 +471,12 @@ def step(
denoised = c_out * predicted_original_sample + c_skip * sample
# 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
- # Noise is not used for one-step sampling.
- if len(self.timesteps) > 1:
- noise = randn_tensor(model_output.shape, generator=generator, device=model_output.device)
+ # Noise is not used on the final timestep of the timestep schedule.
+ # This also means that noise is not used for one-step sampling.
+ if self.step_index != self.num_inference_steps - 1:
+ noise = randn_tensor(
+ model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
+ )
prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
else:
prev_sample = denoised
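
The LCM scheduler change replaces the hard-coded `t / 0.1` with a configurable `timestep_scaling` (default `10.0`, so the default behaviour is unchanged since `t / 0.1 == t * 10`), giving `c_skip = sigma_data**2 / ((s*t)**2 + sigma_data**2)` and `c_out = s*t / ((s*t)**2 + sigma_data**2) ** 0.5`; it also injects noise on every step except the last scheduled one rather than keying only on one-step sampling. A quick numeric check of the scalings (plain Python, `sigma_data = 0.5` as in the scheduler):

```python
sigma_data = 0.5
timestep_scaling = 10.0  # the new config default

def boundary_scalings(timestep: float):
    scaled = timestep * timestep_scaling
    c_skip = sigma_data**2 / (scaled**2 + sigma_data**2)
    c_out = scaled / (scaled**2 + sigma_data**2) ** 0.5
    return c_skip, c_out

for t in (0, 250, 999):
    print(t, boundary_scalings(t))
# t=0 gives (1.0, 0.0), i.e. the consistency-model identity boundary condition;
# as t grows, c_skip -> 0 and c_out -> 1, so the model output dominates.
```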
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index d6d74a89cafb..090b1081fdaf 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -47,6 +47,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class ConsistencyDecoderVAE(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class ControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py
index 6217d1cd28cd..68e986790d76 100644
--- a/tests/lora/test_lora_layers_peft.py
+++ b/tests/lora/test_lora_layers_peft.py
@@ -28,10 +28,12 @@
from diffusers import (
AutoencoderKL,
+ AutoPipelineForImage2Image,
ControlNetModel,
DDIMScheduler,
DiffusionPipeline,
EulerDiscreteScheduler,
+ LCMScheduler,
StableDiffusionPipeline,
StableDiffusionXLControlNetPipeline,
StableDiffusionXLPipeline,
@@ -107,10 +109,12 @@ class PeftLoraLoaderMixinTests:
unet_kwargs = None
vae_kwargs = None
- def get_dummy_components(self):
+ def get_dummy_components(self, scheduler_cls=None):
+ scheduler_cls = self.scheduler_cls if scheduler_cls is None else LCMScheduler
+
torch.manual_seed(0)
unet = UNet2DConditionModel(**self.unet_kwargs)
- scheduler = self.scheduler_cls(**self.scheduler_kwargs)
+ scheduler = scheduler_cls(**self.scheduler_kwargs)
torch.manual_seed(0)
vae = AutoencoderKL(**self.vae_kwargs)
text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2")
@@ -200,746 +204,806 @@ def test_simple_inference(self):
"""
Tests a simple inference and makes sure it works as expected
"""
- components, _, _, _ = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs()
- output_no_lora = pipe(**inputs).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ _, _, inputs = self.get_dummy_inputs()
+ output_no_lora = pipe(**inputs).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
def test_simple_inference_with_text_lora(self):
"""
Tests a simple inference with lora attached on the text encoder
and makes sure it works as expected
"""
- components, _, text_lora_config, _ = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
+ pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
- output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(
- not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
- )
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
+
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(
+ not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ )
def test_simple_inference_with_text_lora_and_scale(self):
"""
Tests a simple inference with lora attached on the text encoder + scale argument
and makes sure it works as expected
"""
- components, _, text_lora_config, _ = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
+ pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
- output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(
- not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
- )
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- output_lora_scale = pipe(
- **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
- ).images
- self.assertTrue(
- not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
- "Lora + scale should change the output",
- )
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(
+ not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ )
- output_lora_0_scale = pipe(
- **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
- ).images
- self.assertTrue(
- np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
- "Lora + 0 scale should lead to same result as no LoRA",
- )
+ output_lora_scale = pipe(
+ **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
+ ).images
+ self.assertTrue(
+ not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+ "Lora + scale should change the output",
+ )
+
+ output_lora_0_scale = pipe(
+ **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
+ ).images
+ self.assertTrue(
+ np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
+ "Lora + 0 scale should lead to same result as no LoRA",
+ )
def test_simple_inference_with_text_lora_fused(self):
"""
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected
"""
- components, _, text_lora_config, _ = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
+ pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
- pipe.fuse_lora()
- # Fusing should still keep the LoRA layers
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- if self.has_two_text_encoders:
+ pipe.fuse_lora()
+ # Fusing should still keep the LoRA layers
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
- ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertFalse(
- np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
- )
+ if self.has_two_text_encoders:
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
+
+ ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertFalse(
+ np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+ )
def test_simple_inference_with_text_lora_unloaded(self):
"""
Tests a simple inference with lora attached to text encoder, then unloads the lora weights
and makes sure it works as expected
"""
- components, _, text_lora_config, _ = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
+ pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
- pipe.unload_lora_weights()
- # unloading should remove the LoRA layers
- self.assertFalse(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
- )
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- if self.has_two_text_encoders:
+ pipe.unload_lora_weights()
+ # unloading should remove the LoRA layers
self.assertFalse(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly unloaded in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
)
- ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(
- np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
- )
+ if self.has_two_text_encoders:
+ self.assertFalse(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2),
+ "Lora not correctly unloaded in text encoder 2",
+ )
+
+ ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(
+ np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
+ "Fused lora should change the output",
+ )
def test_simple_inference_with_text_lora_save_load(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA.
"""
- components, _, text_lora_config, _ = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
+ pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
- images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
if self.has_two_text_encoders:
- text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
-
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname,
- text_encoder_lora_layers=text_encoder_state_dict,
- text_encoder_2_lora_layers=text_encoder_2_state_dict,
- safe_serialization=False,
- )
- else:
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname,
- text_encoder_lora_layers=text_encoder_state_dict,
- safe_serialization=False,
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
- pipe.unload_lora_weights()
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
+ if self.has_two_text_encoders:
+ text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
- images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname,
+ text_encoder_lora_layers=text_encoder_state_dict,
+ text_encoder_2_lora_layers=text_encoder_2_state_dict,
+ safe_serialization=False,
+ )
+ else:
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname,
+ text_encoder_lora_layers=text_encoder_state_dict,
+ safe_serialization=False,
+ )
- if self.has_two_text_encoders:
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+ pipe.unload_lora_weights()
+
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+
+ images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
- self.assertTrue(
- np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
- )
+ if self.has_two_text_encoders:
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
+
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
def test_simple_inference_save_pretrained(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
- components, _, text_lora_config, _ = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
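+ # exercise the test with both a standard (DDIM) and an LCM scheduler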
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
+ pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
- images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- with tempfile.TemporaryDirectory() as tmpdirname:
- pipe.save_pretrained(tmpdirname)
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
- pipe_from_pretrained.to(self.torch_device)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pipe.save_pretrained(tmpdirname)
- self.assertTrue(
- self.check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
- "Lora not correctly set in text encoder",
- )
+ pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
+ pipe_from_pretrained.to(self.torch_device)
- if self.has_two_text_encoders:
self.assertTrue(
- self.check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
- "Lora not correctly set in text encoder 2",
+ self.check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
+ "Lora not correctly set in text encoder",
)
- images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0)).images
+ if self.has_two_text_encoders:
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
+ "Lora not correctly set in text encoder 2",
+ )
- self.assertTrue(
- np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
- )
+ images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0)).images
+
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
def test_simple_inference_with_text_unet_lora_save_load(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
"""
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config)
- pipe.unet.add_adapter(unet_lora_config)
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.unet.add_adapter(unet_lora_config)
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
- images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
- unet_state_dict = get_peft_model_state_dict(pipe.unet)
if self.has_two_text_encoders:
- text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
-
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname,
- text_encoder_lora_layers=text_encoder_state_dict,
- text_encoder_2_lora_layers=text_encoder_2_state_dict,
- unet_lora_layers=unet_state_dict,
- safe_serialization=False,
- )
- else:
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname,
- text_encoder_lora_layers=text_encoder_state_dict,
- unet_lora_layers=unet_state_dict,
- safe_serialization=False,
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
- pipe.unload_lora_weights()
-
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
+ unet_state_dict = get_peft_model_state_dict(pipe.unet)
+ if self.has_two_text_encoders:
+ text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
+
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname,
+ text_encoder_lora_layers=text_encoder_state_dict,
+ text_encoder_2_lora_layers=text_encoder_2_state_dict,
+ unet_lora_layers=unet_state_dict,
+ safe_serialization=False,
+ )
+ else:
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname,
+ text_encoder_lora_layers=text_encoder_state_dict,
+ unet_lora_layers=unet_state_dict,
+ safe_serialization=False,
+ )
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+ pipe.unload_lora_weights()
+
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+
+ images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ )
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
- images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ if self.has_two_text_encoders:
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- if self.has_two_text_encoders:
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
)
- self.assertTrue(
- np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
- )
-
def test_simple_inference_with_text_unet_lora_and_scale(self):
"""
Tests a simple inference with lora attached on the text encoder + Unet + scale argument
and makes sure it works as expected
"""
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+
+ pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.unet.add_adapter(unet_lora_config)
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ )
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
- pipe.text_encoder.add_adapter(text_lora_config)
- pipe.unet.add_adapter(unet_lora_config)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
- output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(
- not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
- )
-
- output_lora_scale = pipe(
- **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
- ).images
- self.assertTrue(
- not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
- "Lora + scale should change the output",
- )
+ output_lora_scale = pipe(
+ **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
+ ).images
+ self.assertTrue(
+ not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+ "Lora + scale should change the output",
+ )
- output_lora_0_scale = pipe(
- **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
- ).images
- self.assertTrue(
- np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
- "Lora + 0 scale should lead to same result as no LoRA",
- )
+ output_lora_0_scale = pipe(
+ **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
+ ).images
+ self.assertTrue(
+ np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
+ "Lora + 0 scale should lead to same result as no LoRA",
+ )
- self.assertTrue(
- pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
- "The scaling parameter has not been correctly restored!",
- )
+ self.assertTrue(
+ pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
+ "The scaling parameter has not been correctly restored!",
+ )
def test_simple_inference_with_text_lora_unet_fused(self):
"""
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet
"""
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config)
- pipe.unet.add_adapter(unet_lora_config)
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.unet.add_adapter(unet_lora_config)
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
- pipe.fuse_lora()
- # Fusing should still keep the LoRA layers
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet")
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- if self.has_two_text_encoders:
+ pipe.fuse_lora()
+ # Fusing should still keep the LoRA layers
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet")
- ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertFalse(
- np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
- )
+ if self.has_two_text_encoders:
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
+
+ output_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertFalse(
+ np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+ )
def test_simple_inference_with_text_unet_lora_unloaded(self):
"""
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
-
- pipe.text_encoder.add_adapter(text_lora_config)
- pipe.unet.add_adapter(unet_lora_config)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
-
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+
+ pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.unet.add_adapter(unet_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
- pipe.unload_lora_weights()
- # unloading should remove the LoRA layers
- self.assertFalse(
- self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
- )
- self.assertFalse(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly unloaded in Unet")
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- if self.has_two_text_encoders:
+ pipe.unload_lora_weights()
+ # unloading should remove the LoRA layers
self.assertFalse(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly unloaded in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
)
+ self.assertFalse(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly unloaded in Unet")
- ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(
- np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
- )
+ if self.has_two_text_encoders:
+ self.assertFalse(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2),
+ "Lora not correctly unloaded in text encoder 2",
+ )
+
+ output_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(
+ np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
+ "Unloading LoRA should give the same output as no LoRA",
+ )
def test_simple_inference_with_text_unet_lora_unfused(self):
"""
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config)
- pipe.unet.add_adapter(unet_lora_config)
+ pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.unet.add_adapter(unet_lora_config)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
-
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
- pipe.fuse_lora()
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ pipe.fuse_lora()
- pipe.unfuse_lora()
+ output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- # unloading should remove the LoRA layers
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Unfuse should still keep LoRA layers")
+ pipe.unfuse_lora()
- if self.has_two_text_encoders:
+ output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ # unfusing should still keep the LoRA layers
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers"
)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Unfuse should still keep LoRA layers")
- # Fuse and unfuse should lead to the same results
- self.assertTrue(
- np.allclose(output_fused_lora, output_unfused_lora, atol=1e-3, rtol=1e-3),
- "Fused lora should change the output",
- )
+ if self.has_two_text_encoders:
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
+ )
+
+ # Fuse and unfuse should lead to the same results
+ self.assertTrue(
+ np.allclose(output_fused_lora, output_unfused_lora, atol=1e-3, rtol=1e-3),
+ "Fused lora should change the output",
+ )
def test_simple_inference_with_text_unet_multi_adapter(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
"""
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- pipe.unet.add_adapter(unet_lora_config, "adapter-1")
- pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-2")
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
- pipe.set_adapters("adapter-1")
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+ pipe.set_adapters("adapter-1")
- pipe.set_adapters("adapter-2")
- output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
+ output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
- pipe.set_adapters(["adapter-1", "adapter-2"])
+ pipe.set_adapters("adapter-2")
+ output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
- output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
+ pipe.set_adapters(["adapter-1", "adapter-2"])
- # Fuse and unfuse should lead to the same results
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
+ output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 1 and mixed adapters should give different results",
- )
+ # Individual adapters and their combination should give different results
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and 2 should give different results",
+ )
- self.assertFalse(
- np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 2 and mixed adapters should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and mixed adapters should give different results",
+ )
- pipe.disable_lora()
+ self.assertFalse(
+ np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 2 and mixed adapters should give different results",
+ )
- output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
+ pipe.disable_lora()
- self.assertTrue(
- np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
+ output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ self.assertTrue(
+ np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
"""
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- pipe.unet.add_adapter(unet_lora_config, "adapter-1")
- pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-2")
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
- pipe.set_adapters("adapter-1")
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+ pipe.set_adapters("adapter-1")
- pipe.set_adapters("adapter-2")
- output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
+ output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
- pipe.set_adapters(["adapter-1", "adapter-2"])
+ pipe.set_adapters("adapter-2")
+ output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
- output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
+ pipe.set_adapters(["adapter-1", "adapter-2"])
- # Fuse and unfuse should lead to the same results
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
+ output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 1 and mixed adapters should give different results",
- )
+ # Individual adapters and their combination should give different results
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and 2 should give different results",
+ )
- self.assertFalse(
- np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 2 and mixed adapters should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and mixed adapters should give different results",
+ )
- pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
- output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertFalse(
+ np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 2 and mixed adapters should give different results",
+ )
- self.assertFalse(
- np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Weighted adapter and mixed adapter should give different results",
- )
+ pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
+ output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0)).images
- pipe.disable_lora()
+ self.assertFalse(
+ np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Weighted adapter and mixed adapter should give different results",
+ )
- output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
+ pipe.disable_lora()
- self.assertTrue(
- np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
+ output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
- def test_lora_fuse_nan(self):
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ self.assertTrue(
+ np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ def test_lora_fuse_nan(self):
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-1")
- # corrupt one LoRA weight with `inf` values
- with torch.no_grad():
- pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
- "inf"
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+
+ # corrupt one LoRA weight with `inf` values
+ with torch.no_grad():
+ pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
+ "inf"
+ )
- # with `safe_fusing=True` we should see an Error
- with self.assertRaises(ValueError):
- pipe.fuse_lora(safe_fusing=True)
+ # with `safe_fusing=True` we should see an Error
+ with self.assertRaises(ValueError):
+ pipe.fuse_lora(safe_fusing=True)
- # without we should not see an error, but every image will be black
- pipe.fuse_lora(safe_fusing=False)
+ # without `safe_fusing` we should not see an error, but every image will be black
+ pipe.fuse_lora(safe_fusing=False)
- out = pipe("test", num_inference_steps=2, output_type="np").images
+ out = pipe("test", num_inference_steps=2, output_type="np").images
- self.assertTrue(np.isnan(out).all())
+ self.assertTrue(np.isnan(out).all())
def test_get_adapters(self):
"""
Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results
"""
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-1")
- adapter_names = pipe.get_active_adapters()
- self.assertListEqual(adapter_names, ["adapter-1"])
+ adapter_names = pipe.get_active_adapters()
+ self.assertListEqual(adapter_names, ["adapter-1"])
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-2")
- adapter_names = pipe.get_active_adapters()
- self.assertListEqual(adapter_names, ["adapter-2"])
+ adapter_names = pipe.get_active_adapters()
+ self.assertListEqual(adapter_names, ["adapter-2"])
- pipe.set_adapters(["adapter-1", "adapter-2"])
- self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])
def test_get_list_adapters(self):
"""
Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results
"""
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-1")
- adapter_names = pipe.get_list_adapters()
- self.assertDictEqual(adapter_names, {"text_encoder": ["adapter-1"], "unet": ["adapter-1"]})
+ adapter_names = pipe.get_list_adapters()
+ self.assertDictEqual(adapter_names, {"text_encoder": ["adapter-1"], "unet": ["adapter-1"]})
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ pipe.unet.add_adapter(unet_lora_config, "adapter-2")
- adapter_names = pipe.get_list_adapters()
- self.assertDictEqual(
- adapter_names, {"text_encoder": ["adapter-1", "adapter-2"], "unet": ["adapter-1", "adapter-2"]}
- )
+ adapter_names = pipe.get_list_adapters()
+ self.assertDictEqual(
+ adapter_names, {"text_encoder": ["adapter-1", "adapter-2"], "unet": ["adapter-1", "adapter-2"]}
+ )
- pipe.set_adapters(["adapter-1", "adapter-2"])
- self.assertDictEqual(
- pipe.get_list_adapters(), {"unet": ["adapter-1", "adapter-2"], "text_encoder": ["adapter-1", "adapter-2"]}
- )
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ self.assertDictEqual(
+ pipe.get_list_adapters(),
+ {"unet": ["adapter-1", "adapter-2"], "text_encoder": ["adapter-1", "adapter-2"]},
+ )
- pipe.unet.add_adapter(unet_lora_config, "adapter-3")
- self.assertDictEqual(
- pipe.get_list_adapters(),
- {"unet": ["adapter-1", "adapter-2", "adapter-3"], "text_encoder": ["adapter-1", "adapter-2"]},
- )
+ pipe.unet.add_adapter(unet_lora_config, "adapter-3")
+ self.assertDictEqual(
+ pipe.get_list_adapters(),
+ {"unet": ["adapter-1", "adapter-2", "adapter-3"], "text_encoder": ["adapter-1", "adapter-2"]},
+ )
@unittest.skip("This is failing for now - need to investigate")
def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
@@ -947,32 +1011,35 @@ def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
- components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(self.torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+ components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(self.torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config)
- pipe.unet.add_adapter(unet_lora_config)
+ pipe.text_encoder.add_adapter(text_lora_config)
+ pipe.unet.add_adapter(unet_lora_config)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
-
- if self.has_two_text_encoders:
- pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
+ self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
- pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
+ self.assertTrue(
+ self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- if self.has_two_text_encoders:
- pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
+
+ if self.has_two_text_encoders:
+ pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
- # Just makes sure it works..
- _ = pipe(**inputs, generator=torch.manual_seed(0)).images
+ # Just makes sure it works.
+ _ = pipe(**inputs, generator=torch.manual_seed(0)).images
class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
@@ -1574,6 +1641,97 @@ def test_sdxl_1_0_lora(self):
self.assertTrue(np.allclose(images, expected, atol=1e-4))
release_memory(pipe)
+ def test_sdxl_lcm_lora(self):
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+ pipe.enable_model_cpu_offload()
+
+ generator = torch.Generator().manual_seed(0)
+
+ lora_model_id = "latent-consistency/lcm-lora-sdxl"
+
+ pipe.load_lora_weights(lora_model_id)
+
+ image = pipe(
+ "masterpiece, best quality, mountain", generator=generator, num_inference_steps=4, guidance_scale=0.5
+ ).images[0]
+
+ expected_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_lora/sdxl_lcm_lora.png"
+ )
+
+ image_np = pipe.image_processor.pil_to_numpy(image)
+ expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)
+
+ self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))
+
+ pipe.unload_lora_weights()
+
+ release_memory(pipe)
+
+ def test_sdv1_5_lcm_lora(self):
+ pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ pipe.to("cuda")
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+ generator = torch.Generator().manual_seed(0)
+
+ lora_model_id = "latent-consistency/lcm-lora-sdv1-5"
+ pipe.load_lora_weights(lora_model_id)
+
+ image = pipe(
+ "masterpiece, best quality, mountain", generator=generator, num_inference_steps=4, guidance_scale=0.5
+ ).images[0]
+
+ expected_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_lora/sdv15_lcm_lora.png"
+ )
+
+ image_np = pipe.image_processor.pil_to_numpy(image)
+ expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)
+
+ self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))
+
+ pipe.unload_lora_weights()
+
+ release_memory(pipe)
+
+ def test_sdv1_5_lcm_lora_img2img(self):
+ pipe = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ pipe.to("cuda")
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+ init_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/img2img/fantasy_landscape.png"
+ )
+
+ generator = torch.Generator().manual_seed(0)
+
+ lora_model_id = "latent-consistency/lcm-lora-sdv1-5"
+ pipe.load_lora_weights(lora_model_id)
+
+ image = pipe(
+ "snowy mountain",
+ generator=generator,
+ image=init_image,
+ strength=0.5,
+ num_inference_steps=4,
+ guidance_scale=0.5,
+ ).images[0]
+
+ expected_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_lora/sdv15_lcm_lora_img2img.png"
+ )
+
+ image_np = pipe.image_processor.pil_to_numpy(image)
+ expected_image_np = pipe.image_processor.pil_to_numpy(expected_image)
+
+ self.assertTrue(np.allclose(image_np, expected_image_np, atol=1e-2))
+
+ pipe.unload_lora_weights()
+
+ release_memory(pipe)
+
def test_sdxl_1_0_lora_fusion(self):
generator = torch.Generator().manual_seed(0)
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 80c97978723c..961147839461 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -196,11 +196,15 @@ def test_forward_with_norm_groups(self):
class ModelTesterMixin:
main_input_name = None # overwrite in model specific tester class
base_precision = 1e-3
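+ # when True, the concrete tester exposes `init_dict` and `inputs_dict(seed)` directly
+ # (instead of `prepare_init_args_and_inputs_for_common`) so each forward call gets freshly built inputs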
+ forward_requires_fresh_args = False
def test_from_save_pretrained(self, expected_max_diff=5e-5):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ if self.forward_requires_fresh_args:
+ model = self.model_class(**self.init_dict)
+ else:
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
- model = self.model_class(**init_dict)
if hasattr(model, "set_default_attn_processor"):
model.set_default_attn_processor()
model.to(torch_device)
@@ -214,11 +218,18 @@ def test_from_save_pretrained(self, expected_max_diff=5e-5):
new_model.to(torch_device)
with torch.no_grad():
- image = model(**inputs_dict)
+ if self.forward_requires_fresh_args:
+ image = model(**self.inputs_dict(0))
+ else:
+ image = model(**inputs_dict)
+
if isinstance(image, dict):
image = image.to_tuple()[0]
- new_image = new_model(**inputs_dict)
+ if self.forward_requires_fresh_args:
+ new_image = new_model(**self.inputs_dict(0))
+ else:
+ new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
@@ -275,8 +286,11 @@ def test_getattr_is_correct(self):
)
def test_set_xformers_attn_processor_for_determinism(self):
torch.use_deterministic_algorithms(False)
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**init_dict)
+ if self.forward_requires_fresh_args:
+ model = self.model_class(**self.init_dict)
+ else:
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
@@ -286,20 +300,42 @@ def test_set_xformers_attn_processor_for_determinism(self):
model.set_default_attn_processor()
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
- output = model(**inputs_dict)[0]
+ if self.forward_requires_fresh_args:
+ output = model(**self.inputs_dict(0))[0]
+ else:
+ output = model(**inputs_dict)[0]
model.enable_xformers_memory_efficient_attention()
assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
- output_2 = model(**inputs_dict)[0]
+ if self.forward_requires_fresh_args:
+ output_2 = model(**self.inputs_dict(0))[0]
+ else:
+ output_2 = model(**inputs_dict)[0]
+
+ model.set_attn_processor(XFormersAttnProcessor())
+ assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
+ with torch.no_grad():
+ if self.forward_requires_fresh_args:
+ output_3 = model(**self.inputs_dict(0))[0]
+ else:
+ output_3 = model(**inputs_dict)[0]
+
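+ # determinism was disabled at the top of this test; re-enable it before comparing outputs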
+ torch.use_deterministic_algorithms(True)
assert torch.allclose(output, output_2, atol=self.base_precision)
+ assert torch.allclose(output, output_3, atol=self.base_precision)
+ assert torch.allclose(output_2, output_3, atol=self.base_precision)
@require_torch_gpu
def test_set_attn_processor_for_determinism(self):
torch.use_deterministic_algorithms(False)
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**init_dict)
+ if self.forward_requires_fresh_args:
+ model = self.model_class(**self.init_dict)
+ else:
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
+
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
@@ -308,32 +344,34 @@ def test_set_attn_processor_for_determinism(self):
assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
with torch.no_grad():
- output_1 = model(**inputs_dict)[0]
+ if self.forward_requires_fresh_args:
+ output_1 = model(**self.inputs_dict(0))[0]
+ else:
+ output_1 = model(**inputs_dict)[0]
model.set_default_attn_processor()
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
- output_2 = model(**inputs_dict)[0]
-
- model.enable_xformers_memory_efficient_attention()
- assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
- with torch.no_grad():
- model(**inputs_dict)[0]
+ if self.forward_requires_fresh_args:
+ output_2 = model(**self.inputs_dict(0))[0]
+ else:
+ output_2 = model(**inputs_dict)[0]
model.set_attn_processor(AttnProcessor2_0())
assert all(type(proc) == AttnProcessor2_0 for proc in model.attn_processors.values())
with torch.no_grad():
- output_4 = model(**inputs_dict)[0]
+ if self.forward_requires_fresh_args:
+ output_4 = model(**self.inputs_dict(0))[0]
+ else:
+ output_4 = model(**inputs_dict)[0]
model.set_attn_processor(AttnProcessor())
assert all(type(proc) == AttnProcessor for proc in model.attn_processors.values())
with torch.no_grad():
- output_5 = model(**inputs_dict)[0]
-
- model.set_attn_processor(XFormersAttnProcessor())
- assert all(type(proc) == XFormersAttnProcessor for proc in model.attn_processors.values())
- with torch.no_grad():
- output_6 = model(**inputs_dict)[0]
+ if self.forward_requires_fresh_args:
+ output_5 = model(**self.inputs_dict(0))[0]
+ else:
+ output_5 = model(**inputs_dict)[0]
torch.use_deterministic_algorithms(True)
@@ -341,12 +379,14 @@ def test_set_attn_processor_for_determinism(self):
assert torch.allclose(output_2, output_1, atol=self.base_precision)
assert torch.allclose(output_2, output_4, atol=self.base_precision)
assert torch.allclose(output_2, output_5, atol=self.base_precision)
- assert torch.allclose(output_2, output_6, atol=self.base_precision)
def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ if self.forward_requires_fresh_args:
+ model = self.model_class(**self.init_dict)
+ else:
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
- model = self.model_class(**init_dict)
if hasattr(model, "set_default_attn_processor"):
model.set_default_attn_processor()
@@ -369,11 +409,17 @@ def test_from_save_pretrained_variant(self, expected_max_diff=5e-5):
new_model.to(torch_device)
with torch.no_grad():
- image = model(**inputs_dict)
+ if self.forward_requires_fresh_args:
+ image = model(**self.inputs_dict(0))
+ else:
+ image = model(**inputs_dict)
if isinstance(image, dict):
image = image.to_tuple()[0]
- new_image = new_model(**inputs_dict)
+ if self.forward_requires_fresh_args:
+ new_image = new_model(**self.inputs_dict(0))
+ else:
+ new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image.to_tuple()[0]
@@ -407,17 +453,26 @@ def test_from_save_pretrained_dtype(self):
assert new_model.dtype == dtype
def test_determinism(self, expected_max_diff=1e-5):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**init_dict)
+ if self.forward_requires_fresh_args:
+ model = self.model_class(**self.init_dict)
+ else:
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
- first = model(**inputs_dict)
+ if self.forward_requires_fresh_args:
+ first = model(**self.inputs_dict(0))
+ else:
+ first = model(**inputs_dict)
if isinstance(first, dict):
first = first.to_tuple()[0]
- second = model(**inputs_dict)
+ if self.forward_requires_fresh_args:
+ second = model(**self.inputs_dict(0))
+ else:
+ second = model(**inputs_dict)
if isinstance(second, dict):
second = second.to_tuple()[0]
@@ -550,15 +605,22 @@ def recursive_check(tuple_object, dict_object):
),
)
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ if self.forward_requires_fresh_args:
+ model = self.model_class(**self.init_dict)
+ else:
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict)
- model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
- outputs_dict = model(**inputs_dict)
- outputs_tuple = model(**inputs_dict, return_dict=False)
+ if self.forward_requires_fresh_args:
+ outputs_dict = model(**self.inputs_dict(0))
+ outputs_tuple = model(**self.inputs_dict(0), return_dict=False)
+ else:
+ outputs_dict = model(**inputs_dict)
+ outputs_tuple = model(**inputs_dict, return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py
index fe2bcdb0af35..3b698624ff87 100644
--- a/tests/models/test_models_vae.py
+++ b/tests/models/test_models_vae.py
@@ -16,11 +16,19 @@
import gc
import unittest
+import numpy as np
import torch
from parameterized import parameterized
-from diffusers import AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderTiny
+from diffusers import (
+ AsymmetricAutoencoderKL,
+ AutoencoderKL,
+ AutoencoderTiny,
+ ConsistencyDecoderVAE,
+ StableDiffusionPipeline,
+)
from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.loading_utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
@@ -30,6 +38,7 @@
torch_all_close,
torch_device,
)
+from diffusers.utils.torch_utils import randn_tensor
from .test_modeling_common import ModelTesterMixin, UNetTesterMixin
@@ -269,6 +278,70 @@ def test_outputs_equivalence(self):
pass
+class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
+ model_class = ConsistencyDecoderVAE
+ main_input_name = "sample"
+ base_precision = 1e-2
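+ # inputs include a stateful torch.Generator, so the shared tests rebuild them for every forward call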
+ forward_requires_fresh_args = True
+
+ def inputs_dict(self, seed=None):
+ generator = torch.Generator("cpu")
+ if seed is not None:
+ generator.manual_seed(seed)
+ image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
+
+ return {"sample": image, "generator": generator}
+
+ @property
+ def input_shape(self):
+ return (3, 32, 32)
+
+ @property
+ def output_shape(self):
+ return (3, 32, 32)
+
+ @property
+ def init_dict(self):
+ return {
+ "encoder_block_out_channels": [32, 64],
+ "encoder_in_channels": 3,
+ "encoder_out_channels": 4,
+ "encoder_down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ "decoder_add_attention": False,
+ "decoder_block_out_channels": [32, 64],
+ "decoder_down_block_types": [
+ "ResnetDownsampleBlock2D",
+ "ResnetDownsampleBlock2D",
+ ],
+ "decoder_downsample_padding": 1,
+ "decoder_in_channels": 7,
+ "decoder_layers_per_block": 1,
+ "decoder_norm_eps": 1e-05,
+ "decoder_norm_num_groups": 32,
+ "decoder_num_train_timesteps": 1024,
+ "decoder_out_channels": 6,
+ "decoder_resnet_time_scale_shift": "scale_shift",
+ "decoder_time_embedding_type": "learned",
+ "decoder_up_block_types": [
+ "ResnetUpsampleBlock2D",
+ "ResnetUpsampleBlock2D",
+ ],
+ "scaling_factor": 1,
+ "latent_channels": 4,
+ }
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return self.init_dict, self.inputs_dict()
+
+ @unittest.skip("Not tested.")
+ def test_training(self):
+ ...
+
+ @unittest.skip("Not tested.")
+ def test_ema_training(self):
+ ...
+
+
@slow
class AutoencoderTinyIntegrationTests(unittest.TestCase):
def tearDown(self):
@@ -721,3 +794,94 @@ def test_stable_diffusion_encode_sample(self, seed, expected_slice):
tolerance = 3e-3 if torch_device != "mps" else 1e-2
assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
+
+
+@slow
+class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
+ def tearDown(self):
+ # clean up the VRAM after each test
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_encode_decode(self):
+ vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
+ vae.to(torch_device)
+
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/img2img/sketch-mountains-input.jpg"
+ ).resize((256, 256))
+ image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[
+ None, :, :, :
+ ].cuda()
+
+ latent = vae.encode(image).latent_dist.mean
+
+ sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
+
+ actual_output = sample[0, :2, :2, :2].flatten().cpu()
+ expected_output = torch.tensor([-0.0141, -0.0014, 0.0115, 0.0086, 0.1051, 0.1053, 0.1031, 0.1024])
+
+ assert torch_all_close(actual_output, expected_output, atol=5e-3)
+
+ def test_sd(self):
+ vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder") # TODO - update
+ pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
+ pipe.to(torch_device)
+
+ out = pipe(
+ "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
+ ).images[0]
+
+ actual_output = out[:2, :2, :2].flatten().cpu()
+ expected_output = torch.tensor([0.7686, 0.8228, 0.6489, 0.7455, 0.8661, 0.8797, 0.8241, 0.8759])
+
+ assert torch_all_close(actual_output, expected_output, atol=5e-3)
+
+ def test_encode_decode_f16(self):
+ vae = ConsistencyDecoderVAE.from_pretrained(
+ "openai/consistency-decoder", torch_dtype=torch.float16
+ ) # TODO - update
+ vae.to(torch_device)
+
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
+ "/img2img/sketch-mountains-input.jpg"
+ ).resize((256, 256))
+ image = (
+ torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :]
+ .half()
+ .cuda()
+ )
+
+ latent = vae.encode(image).latent_dist.mean
+
+ sample = vae.decode(latent, generator=torch.Generator("cpu").manual_seed(0)).sample
+
+ actual_output = sample[0, :2, :2, :2].flatten().cpu()
+ expected_output = torch.tensor(
+ [-0.0111, -0.0125, -0.0017, -0.0007, 0.1257, 0.1465, 0.1450, 0.1471], dtype=torch.float16
+ )
+
+ assert torch_all_close(actual_output, expected_output, atol=5e-3)
+
+ def test_sd_f16(self):
+ vae = ConsistencyDecoderVAE.from_pretrained(
+ "openai/consistency-decoder", torch_dtype=torch.float16
+ ) # TODO - update
+ pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae, safety_checker=None
+ )
+ pipe.to(torch_device)
+
+ out = pipe(
+ "horse", num_inference_steps=2, output_type="pt", generator=torch.Generator("cpu").manual_seed(0)
+ ).images[0]
+
+ actual_output = out[:2, :2, :2].flatten().cpu()
+ expected_output = torch.tensor(
+ [0.0000, 0.0249, 0.0000, 0.0000, 0.1709, 0.2773, 0.0471, 0.1035], dtype=torch.float16
+ )
+
+ assert torch_all_close(actual_output, expected_output, atol=5e-3)
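For reviewers, the `ConsistencyDecoderVAEIntegrationTests` above boil down to one usage pattern: load the consistency decoder, plug it into a Stable Diffusion pipeline in place of the default VAE, and sample with a seeded generator. A minimal sketch of that pattern, reusing the placeholder checkpoint id from the tests (still marked `TODO - update`, so it may change):

```python
# Sketch only: mirrors the integration tests above, not a definitive recipe.
import torch

from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline

vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16, safety_checker=None
).to("cuda")

# Decoding is stochastic, so a generator is threaded through for reproducibility,
# exactly as the tests do with torch.Generator("cpu").manual_seed(0).
image = pipe(
    "horse", num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0)
).images[0]
```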
diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py
index baba8ba4d655..3c9390f2d1b6 100644
--- a/tests/pipelines/animatediff/test_animatediff.py
+++ b/tests/pipelines/animatediff/test_animatediff.py
@@ -220,6 +220,17 @@ def test_to_dtype(self):
model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+ def test_prompt_embeds(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs.pop("prompt")
+ inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device)
+ pipe(**inputs)
+
@slow
@require_torch_gpu
diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
index 82a2944aeda4..53702925534d 100644
--- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
+++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
@@ -133,7 +133,7 @@ def test_lcm_onestep(self):
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5865, 0.2854, 0.2828, 0.7473, 0.6006, 0.4580, 0.4397, 0.6415, 0.6069])
+ expected_slice = np.array([0.4388, 0.3717, 0.2202, 0.7213, 0.6370, 0.3664, 0.5815, 0.6080, 0.4977])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_lcm_multistep(self):
@@ -150,7 +150,7 @@ def test_lcm_multistep(self):
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.4903, 0.3304, 0.3503, 0.5241, 0.5153, 0.4585, 0.3222, 0.4764, 0.4891])
+ expected_slice = np.array([0.4150, 0.3719, 0.2479, 0.6333, 0.6024, 0.3778, 0.5036, 0.5420, 0.4678])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_inference_batch_single_identical(self):
@@ -237,7 +237,7 @@ def test_lcm_onestep(self):
assert image.shape == (1, 512, 512, 3)
image_slice = image[0, -3:, -3:, -1].flatten()
- expected_slice = np.array([0.1025, 0.0911, 0.0984, 0.0981, 0.0901, 0.0918, 0.1055, 0.0940, 0.0730])
+ expected_slice = np.array([0.1950, 0.1961, 0.2308, 0.1786, 0.1837, 0.2320, 0.1898, 0.1885, 0.2309])
assert np.abs(image_slice - expected_slice).max() < 1e-3
def test_lcm_multistep(self):
@@ -253,5 +253,5 @@ def test_lcm_multistep(self):
assert image.shape == (1, 512, 512, 3)
image_slice = image[0, -3:, -3:, -1].flatten()
- expected_slice = np.array([0.01855, 0.01855, 0.01489, 0.01392, 0.01782, 0.01465, 0.01831, 0.02539, 0.0])
+ expected_slice = np.array([0.3756, 0.3816, 0.3767, 0.3718, 0.3739, 0.3735, 0.3863, 0.3803, 0.3563])
assert np.abs(image_slice - expected_slice).max() < 1e-3
diff --git a/tests/pipelines/pixart/test_pixart.py b/tests/pipelines/pixart/test_pixart.py
index 1797f7e0fec2..a04f4e1a8804 100644
--- a/tests/pipelines/pixart/test_pixart.py
+++ b/tests/pipelines/pixart/test_pixart.py
@@ -120,7 +120,6 @@ def test_save_load_optional_components(self):
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
- "mask_feature": False,
}
# set all optional components to None
@@ -155,7 +154,6 @@ def test_save_load_optional_components(self):
"generator": generator,
"num_inference_steps": num_inference_steps,
"output_type": output_type,
- "mask_feature": False,
}
output_loaded = pipe_loaded(**inputs)[0]
@@ -174,18 +172,99 @@ def test_inference(self):
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
- print(torch.from_numpy(image_slice.flatten()))
self.assertEqual(image.shape, (1, 8, 8, 3))
expected_slice = np.array([0.5303, 0.2658, 0.7979, 0.1182, 0.3304, 0.4608, 0.5195, 0.4261, 0.4675])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
+ def test_inference_non_square_images(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs, height=32, width=48).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ self.assertEqual(image.shape, (1, 32, 48, 3))
+ expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416])
+ max_diff = np.abs(image_slice.flatten() - expected_slice).max()
+ self.assertLessEqual(max_diff, 1e-3)
+
+ def test_inference_with_embeddings_and_multiple_images(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ prompt = inputs["prompt"]
+ generator = inputs["generator"]
+ num_inference_steps = inputs["num_inference_steps"]
+ output_type = inputs["output_type"]
+
+ prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(prompt)
+
+ # inputs with prompt converted to embeddings
+ inputs = {
+ "prompt_embeds": prompt_embeds,
+ "negative_prompt": None,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ "generator": generator,
+ "num_inference_steps": num_inference_steps,
+ "output_type": output_type,
+ "num_images_per_prompt": 2,
+ }
+
+ # set all optional components to None
+ for optional_component in pipe._optional_components:
+ setattr(pipe, optional_component, None)
+
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for optional_component in pipe._optional_components:
+ self.assertTrue(
+ getattr(pipe_loaded, optional_component) is None,
+ f"`{optional_component}` did not stay set to None after loading.",
+ )
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ generator = inputs["generator"]
+ num_inference_steps = inputs["num_inference_steps"]
+ output_type = inputs["output_type"]
+
+ # inputs with prompt converted to embeddings
+ inputs = {
+ "prompt_embeds": prompt_embeds,
+ "negative_prompt": None,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ "generator": generator,
+ "num_inference_steps": num_inference_steps,
+ "output_type": output_type,
+ "num_images_per_prompt": 2,
+ }
+
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
+ self.assertLess(max_diff, 1e-4)
+
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-3)
-# TODO: needs to be updated.
@slow
@require_torch_gpu
class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
diff --git a/tests/pipelines/shap_e/test_shap_e.py b/tests/pipelines/shap_e/test_shap_e.py
index 7b95fdd9e669..c7792f097ed5 100644
--- a/tests/pipelines/shap_e/test_shap_e.py
+++ b/tests/pipelines/shap_e/test_shap_e.py
@@ -160,7 +160,7 @@ def get_dummy_inputs(self, device, seed=0):
"generator": generator,
"num_inference_steps": 1,
"frame_size": 32,
- "output_type": "np",
+ "output_type": "latent",
}
return inputs
@@ -176,24 +176,12 @@ def test_shap_e(self):
output = pipe(**self.get_dummy_inputs(device))
image = output.images[0]
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (20, 32, 32, 3)
-
- expected_slice = np.array(
- [
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- ]
- )
+ image = image.cpu().numpy()
+ image_slice = image[-3:, -3:]
+
+ assert image.shape == (32, 16)
+ expected_slice = np.array([-1.0000, -0.6241, 1.0000, -0.8978, -0.6866, 0.7876, -0.7473, -0.2874, 0.6103])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_inference_batch_consistent(self):
diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py
index 055dbe7a97d4..ee8d9d07cd77 100644
--- a/tests/pipelines/shap_e/test_shap_e_img2img.py
+++ b/tests/pipelines/shap_e/test_shap_e_img2img.py
@@ -181,7 +181,7 @@ def get_dummy_inputs(self, device, seed=0):
"generator": generator,
"num_inference_steps": 1,
"frame_size": 32,
- "output_type": "np",
+ "output_type": "latent",
}
return inputs
@@ -197,22 +197,12 @@ def test_shap_e(self):
output = pipe(**self.get_dummy_inputs(device))
image = output.images[0]
- image_slice = image[0, -3:, -3:, -1]
+ image_slice = image[-3:, -3:].cpu().numpy()
- assert image.shape == (20, 32, 32, 3)
+ assert image.shape == (32, 16)
expected_slice = np.array(
- [
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- 0.00039216,
- ]
+ [-1.0, 0.40668195, 0.57322013, -0.9469888, 0.4283227, 0.30348337, -0.81094897, 0.74555075, 0.15342723]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index ad77cc3e2b22..53284b80952c 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -31,6 +31,7 @@
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
+ LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionPipeline,
@@ -41,6 +42,7 @@
from diffusers.utils.testing_utils import (
CaptureLogger,
enable_full_determinism,
+ load_image,
load_numpy,
nightly,
numpy_cosine_similarity_distance,
@@ -107,12 +109,13 @@ class StableDiffusionPipelineFastTests(
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
- def get_dummy_components(self):
+ def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(4, 8),
layers_per_block=1,
sample_size=32,
+ time_cond_proj_dim=time_cond_proj_dim,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
@@ -196,6 +199,26 @@ def test_stable_diffusion_ddim(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ def test_stable_diffusion_lcm(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+
+ components = self.get_dummy_components(time_cond_proj_dim=256)
+ sd_pipe = StableDiffusionPipeline(**components)
+ sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ output = sd_pipe(**inputs)
+ image = output.images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+ expected_slice = np.array([0.3454, 0.5349, 0.5185, 0.2808, 0.4509, 0.4612, 0.4655, 0.3601, 0.4315])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
def test_stable_diffusion_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
@@ -1066,6 +1089,29 @@ def test_stable_diffusion_compile(self):
inputs["seed"] = seed
run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=inputs)
+ def test_stable_diffusion_lcm(self):
+ unet = UNet2DConditionModel.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", subfolder="unet")
+ sd_pipe = StableDiffusionPipeline.from_pretrained("Lykon/dreamshaper-7", unet=unet).to(torch_device)
+ sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_inputs(torch_device)
+ inputs["num_inference_steps"] = 6
+ inputs["output_type"] = "pil"
+
+ image = sd_pipe(**inputs).images[0]
+
+ expected_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_full/stable_diffusion_lcm.png"
+ )
+
+ image = sd_pipe.image_processor.pil_to_numpy(image)
+ expected_image = sd_pipe.image_processor.pil_to_numpy(expected_image)
+
+ max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
+
+ assert max_diff < 1e-2
+
@slow
@require_torch_gpu
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py
index cd688c3beb37..2e9d7c3b437b 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py
@@ -36,7 +36,6 @@
load_numpy,
nightly,
numpy_cosine_similarity_distance,
- print_tensor_test,
require_torch_gpu,
slow,
torch_device,
@@ -202,7 +201,6 @@ def test_stable_diffusion_img_variation_pipeline_default(self):
assert image.shape == (1, 512, 512, 3)
expected_slice = np.array([0.8449, 0.9079, 0.7571, 0.7873, 0.8348, 0.7010, 0.6694, 0.6873, 0.6138])
- print_tensor_test(image_slice)
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 1e-4
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
index 9e365e860f0e..12c6d8cf63d3 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
@@ -28,6 +28,7 @@
DDIMScheduler,
DPMSolverMultistepScheduler,
HeunDiscreteScheduler,
+ LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionImg2ImgPipeline,
@@ -103,11 +104,12 @@ class StableDiffusionImg2ImgPipelineFastTests(
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
- def get_dummy_components(self):
+ def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
+ time_cond_proj_dim=time_cond_proj_dim,
sample_size=32,
in_channels=4,
out_channels=4,
@@ -187,6 +189,23 @@ def test_stable_diffusion_img2img_default_case(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+ def test_stable_diffusion_img2img_default_case_lcm(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components(time_cond_proj_dim=256)
+ sd_pipe = StableDiffusionImg2ImgPipeline(**components)
+ sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 32, 32, 3)
+ expected_slice = np.array([0.5709, 0.4614, 0.4587, 0.5978, 0.5298, 0.6910, 0.6240, 0.5212, 0.5454])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
def test_stable_diffusion_img2img_negative_prompt(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
index 53818c91295a..5aa678bbf5f3 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -29,6 +29,7 @@
AutoencoderKL,
DDIMScheduler,
DPMSolverMultistepScheduler,
+ LCMScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
StableDiffusionInpaintPipeline,
@@ -106,10 +107,11 @@ class StableDiffusionInpaintPipelineFastTests(
image_latents_params = frozenset([])
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"mask", "masked_image_latents"})
- def get_dummy_components(self):
+ def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
+ time_cond_proj_dim=time_cond_proj_dim,
layers_per_block=2,
sample_size=32,
in_channels=9,
@@ -252,6 +254,23 @@ def test_stable_diffusion_inpaint(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ def test_stable_diffusion_inpaint_lcm(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components(time_cond_proj_dim=256)
+ sd_pipe = StableDiffusionInpaintPipeline(**components)
+ sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+ expected_slice = np.array([0.4931, 0.5988, 0.4569, 0.5556, 0.6650, 0.5087, 0.5966, 0.5358, 0.5269])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
def test_stable_diffusion_inpaint_image_tensor(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -334,11 +353,12 @@ class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipeli
image_params = frozenset([])
# TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
- def get_dummy_components(self):
+ def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
+ time_cond_proj_dim=time_cond_proj_dim,
sample_size=32,
in_channels=4,
out_channels=4,
@@ -427,6 +447,23 @@ def test_stable_diffusion_inpaint(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ def test_stable_diffusion_inpaint_lcm(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components(time_cond_proj_dim=256)
+ sd_pipe = StableDiffusionInpaintPipeline(**components)
+ sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+ expected_slice = np.array([0.6240, 0.5355, 0.5649, 0.5378, 0.5374, 0.6242, 0.5132, 0.5347, 0.5396])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
def test_stable_diffusion_inpaint_2_images(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index d2d00d9c0110..95fbb658fe5e 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -27,12 +27,20 @@
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
+ LCMScheduler,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
UniPCMultistepScheduler,
)
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ load_image,
+ numpy_cosine_similarity_distance,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
@@ -56,11 +64,12 @@ class StableDiffusionXLPipelineFastTests(
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
- def get_dummy_components(self):
+ def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(2, 4),
layers_per_block=2,
+ time_cond_proj_dim=time_cond_proj_dim,
sample_size=32,
in_channels=4,
out_channels=4,
@@ -155,6 +164,23 @@ def test_stable_diffusion_xl_euler(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ def test_stable_diffusion_xl_euler_lcm(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components(time_cond_proj_dim=256)
+ sd_pipe = StableDiffusionXLPipeline(**components)
+ sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+ expected_slice = np.array([0.4917, 0.6555, 0.4348, 0.5219, 0.7324, 0.4855, 0.5168, 0.5447, 0.5156])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
def test_stable_diffusion_xl_prompt_embeds(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**components)
@@ -890,3 +916,32 @@ def test_stable_diffusion_xl_save_from_pretrained(self):
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+
+
+@slow
+class StableDiffusionXLPipelineIntegrationTests(unittest.TestCase):
+ def test_stable_diffusion_lcm(self):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel.from_pretrained(
+ "latent-consistency/lcm-ssd-1b", torch_dtype=torch.float16, variant="fp16"
+ )
+ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
+ "segmind/SSD-1B", unet=unet, torch_dtype=torch.float16, variant="fp16"
+ ).to(torch_device)
+ sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ prompt = "a red car standing on the side of the street"
+
+ image = sd_pipe(prompt, num_inference_steps=4, guidance_scale=8.0).images[0]
+
+ expected_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/lcm_full/stable_diffusion_ssd_1b_lcm.png"
+ )
+
+ image = sd_pipe.image_processor.pil_to_numpy(image)
+ expected_image = sd_pipe.image_processor.pil_to_numpy(expected_image)
+
+ max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
+
+ assert max_diff < 1e-2
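The new `StableDiffusionXLPipelineIntegrationTests` class above exercises the few-step LCM path end to end: load an LCM-distilled UNet, swap in `LCMScheduler`, and sample with a handful of steps. Condensed into a standalone sketch (same checkpoints and call signature as the test; the prompt is arbitrary):

```python
# Sketch only: the LCM recipe verified by the integration test above.
import torch

from diffusers import LCMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "latent-consistency/lcm-ssd-1b", torch_dtype=torch.float16, variant="fp16"
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "segmind/SSD-1B", unet=unet, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
# Derive the LCM scheduler from the existing scheduler config, as the tests do.
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

image = pipe(
    "a red car standing on the side of the street", num_inference_steps=4, guidance_scale=8.0
).images[0]
```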
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
index c3fb397956fa..55779e5f060d 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
@@ -24,6 +24,7 @@
AutoencoderKL,
AutoencoderTiny,
EulerDiscreteScheduler,
+ LCMScheduler,
StableDiffusionXLImg2ImgPipeline,
UNet2DConditionModel,
)
@@ -57,7 +58,7 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
{"add_text_embeds", "add_time_ids", "add_neg_time_ids"}
)
- def get_dummy_components(self, skip_first_text_encoder=False):
+ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
@@ -65,6 +66,7 @@ def get_dummy_components(self, skip_first_text_encoder=False):
sample_size=32,
in_channels=4,
out_channels=4,
+ time_cond_proj_dim=time_cond_proj_dim,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
# SD2-specific config below
@@ -172,6 +174,24 @@ def test_stable_diffusion_xl_img2img_euler(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ def test_stable_diffusion_xl_img2img_euler_lcm(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components(time_cond_proj_dim=256)
+ sd_pipe = StableDiffusionXLImg2ImgPipeline(**components)
+ sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 32, 32, 3)
+
+ expected_slice = np.array([0.5604, 0.4352, 0.4717, 0.5844, 0.5101, 0.6704, 0.6290, 0.5460, 0.5286])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
index aa607c23ffda..54c750f997b6 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
@@ -28,6 +28,7 @@
DPMSolverMultistepScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
+ LCMScheduler,
StableDiffusionXLInpaintPipeline,
UNet2DConditionModel,
UniPCMultistepScheduler,
@@ -61,7 +62,7 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
}
)
- def get_dummy_components(self, skip_first_text_encoder=False):
+ def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
@@ -69,6 +70,7 @@ def get_dummy_components(self, skip_first_text_encoder=False):
sample_size=32,
in_channels=4,
out_channels=4,
+ time_cond_proj_dim=time_cond_proj_dim,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
# SD2-specific config below
@@ -209,6 +211,24 @@ def test_stable_diffusion_xl_inpaint_euler(self):
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ def test_stable_diffusion_xl_inpaint_euler_lcm(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components(time_cond_proj_dim=256)
+ sd_pipe = StableDiffusionXLInpaintPipeline(**components)
+ sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+ sd_pipe = sd_pipe.to(device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = sd_pipe(**inputs).images
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+
+ expected_slice = np.array([0.6611, 0.5569, 0.5531, 0.5471, 0.5918, 0.6393, 0.5074, 0.5468, 0.5185])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index 875fd787c8b0..42c90e47af80 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -14,7 +14,6 @@
# limitations under the License.
import gc
-import glob
import json
import os
import random
@@ -57,7 +56,7 @@
UniPCMultistepScheduler,
logging,
)
-from diffusers.pipelines.pipeline_utils import _get_pipeline_class, variant_compatible_siblings
+from diffusers.pipelines.pipeline_utils import _get_pipeline_class
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import (
CONFIG_NAME,
@@ -1505,28 +1504,15 @@ def test_name_or_path(self):
assert sd.name_or_path == tmpdirname
- def test_warning_no_variant_available(self):
+ def test_error_no_variant_available(self):
variant = "fp16"
- with self.assertWarns(FutureWarning) as warning_context:
- cached_folder = StableDiffusionPipeline.download(
+ with self.assertRaises(ValueError) as error_context:
+ _ = StableDiffusionPipeline.download(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", variant=variant
)
- assert "but no such modeling files are available" in str(warning_context.warning)
- assert variant in str(warning_context.warning)
-
- def get_all_filenames(directory):
- filenames = glob.glob(directory + "/**", recursive=True)
- filenames = [f for f in filenames if os.path.isfile(f)]
- return filenames
-
- filenames = get_all_filenames(str(cached_folder))
-
- all_model_files, variant_model_files = variant_compatible_siblings(filenames, variant=variant)
-
- # make sure that none of the model names are variant model names
- assert len(variant_model_files) == 0
- assert len(all_model_files) > 0
+ assert "but no such modeling files are available" in str(error_context.exception)
+ assert variant in str(error_context.exception)
def test_pipe_to(self):
unet = self.dummy_cond_unet()
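The change above tightens the contract for missing variants: `StableDiffusionPipeline.download` now raises a `ValueError` instead of emitting a `FutureWarning` when the requested variant has no matching model files. A minimal sketch of the new behaviour, using the same tiny test repo as `test_error_no_variant_available`:

```python
# Sketch only: demonstrates the behaviour asserted by the updated test above.
from diffusers import StableDiffusionPipeline

try:
    StableDiffusionPipeline.download(
        "hf-internal-testing/diffusers-stable-diffusion-tiny-all", variant="fp16"
    )
except ValueError as err:
    # The message names the missing variant, e.g.
    # "... fp16 ... but no such modeling files are available ..."
    print(err)
```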
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 1795c83b58a1..b9fe4d190f23 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -493,7 +493,7 @@ def _test_inference_batch_single_identical(
assert output_batch[0].shape[0] == batch_size
- max_diff = np.abs(output_batch[0][0] - output[0][0]).max()
+ max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
assert max_diff < expected_max_diff
def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
@@ -702,7 +702,7 @@ def _test_attention_slicing_forward_pass(
self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
if test_mean_pixel_difference:
- assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
+ assert_mean_pixel_difference(to_np(output_with_slicing[0]), to_np(output_without_slicing[0]))
@unittest.skipIf(
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
diff --git a/tests/schedulers/test_scheduler_lcm.py b/tests/schedulers/test_scheduler_lcm.py
index 48b68fa47ddc..f7d511ff0573 100644
--- a/tests/schedulers/test_scheduler_lcm.py
+++ b/tests/schedulers/test_scheduler_lcm.py
@@ -230,7 +230,7 @@ def test_full_loop_onestep(self):
result_mean = torch.mean(torch.abs(sample))
# TODO: get expected sum and mean
- assert abs(result_sum.item() - 18.7097) < 1e-2
+ assert abs(result_sum.item() - 18.7097) < 1e-3
assert abs(result_mean.item() - 0.0244) < 1e-3
def test_full_loop_multistep(self):
@@ -240,5 +240,5 @@ def test_full_loop_multistep(self):
result_mean = torch.mean(torch.abs(sample))
# TODO: get expected sum and mean
- assert abs(result_sum.item() - 280.5618) < 1e-2
- assert abs(result_mean.item() - 0.3653) < 1e-3
+ assert abs(result_sum.item() - 197.7616) < 1e-3
+ assert abs(result_mean.item() - 0.2575) < 1e-3