diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index c0ae87521fdc..718c67731bd5 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -31,7 +31,6 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install pandas peft
diff --git a/.github/workflows/build_docker_images.yml b/.github/workflows/build_docker_images.yml
index 3d061f88c241..82ef885b240e 100644
--- a/.github/workflows/build_docker_images.yml
+++ b/.github/workflows/build_docker_images.yml
@@ -20,7 +20,7 @@ env:
jobs:
test-build-docker-images:
- runs-on: ubuntu-latest
+ runs-on: [ self-hosted, intel-cpu, 8-cpu, ci ]
if: github.event_name == 'pull_request'
steps:
- name: Set up Docker Buildx
@@ -50,7 +50,7 @@ jobs:
if: steps.file_changes.outputs.all != ''
build-and-push-docker-images:
- runs-on: ubuntu-latest
+ runs-on: [ self-hosted, intel-cpu, 8-cpu, ci ]
if: github.event_name != 'pull_request'
permissions:
@@ -73,13 +73,13 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v3
-
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v1
- name: Login to Docker Hub
uses: docker/login-action@v2
with:
username: ${{ env.REGISTRY }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
-
- name: Build and push
uses: docker/build-push-action@v3
with:
diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml
index d489da8e48eb..2f73c66de829 100644
--- a/.github/workflows/nightly_tests.yml
+++ b/.github/workflows/nightly_tests.yml
@@ -70,7 +70,6 @@ jobs:
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
@@ -131,7 +130,6 @@ jobs:
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
@@ -202,7 +200,6 @@ jobs:
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
@@ -262,7 +259,6 @@ jobs:
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
diff --git a/.github/workflows/pr_test_fetcher.yml b/.github/workflows/pr_test_fetcher.yml
index 6a7da2cb73d0..4dbb118c6092 100644
--- a/.github/workflows/pr_test_fetcher.yml
+++ b/.github/workflows/pr_test_fetcher.yml
@@ -32,7 +32,6 @@ jobs:
fetch-depth: 0
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
- name: Environment
@@ -89,7 +88,6 @@ jobs:
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pip install -e [quality,test]
python -m pip install accelerate
@@ -147,7 +145,6 @@ jobs:
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pip install -e [quality,test]
diff --git a/.github/workflows/pr_test_peft_backend.yml b/.github/workflows/pr_test_peft_backend.yml
index c7a6ea4fb7c7..b4915a3bf4d2 100644
--- a/.github/workflows/pr_test_peft_backend.yml
+++ b/.github/workflows/pr_test_peft_backend.yml
@@ -32,9 +32,7 @@ jobs:
python -m pip install --upgrade pip
pip install .[quality]
- name: Check quality
- run: |
- ruff check examples tests src utils scripts
- ruff format examples tests src utils scripts --check
+ run: make quality
- name: Check if failure
if: ${{ failure() }}
run: |
@@ -53,7 +51,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install .[quality]
- - name: Check quality
+ - name: Check repo consistency
run: |
python utils/check_copies.py
python utils/check_dummies.py
@@ -73,7 +71,7 @@ jobs:
name: LoRA - ${{ matrix.lib-versions }}
- runs-on: docker-cpu
+ runs-on: [ self-hosted, intel-cpu, 8-cpu, ci ]
container:
image: diffusers/diffusers-pytorch-cpu
@@ -91,11 +89,10 @@ jobs:
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
if [ "${{ matrix.lib-versions }}" == "main" ]; then
- python -m uv pip install -U peft@git+https://github.com/huggingface/peft.git
+ python -m pip install -U peft@git+https://github.com/huggingface/peft.git
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git
python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
else
@@ -110,7 +107,7 @@ jobs:
- name: Run fast PyTorch LoRA CPU tests with PEFT backend
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v \
--make-reports=tests_${{ matrix.config.report }} \
tests/lora/
diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml
index 7ec4ffa713b8..b1bed6568aa4 100644
--- a/.github/workflows/pr_tests.yml
+++ b/.github/workflows/pr_tests.yml
@@ -40,9 +40,7 @@ jobs:
python -m pip install --upgrade pip
pip install .[quality]
- name: Check quality
- run: |
- ruff check examples tests src utils scripts
- ruff format examples tests src utils scripts --check
+ run: make quality
- name: Check if failure
if: ${{ failure() }}
run: |
@@ -61,7 +59,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install .[quality]
- - name: Check quality
+ - name: Check repo consistency
run: |
python utils/check_copies.py
python utils/check_dummies.py
@@ -79,22 +77,22 @@ jobs:
config:
- name: Fast PyTorch Pipeline CPU tests
framework: pytorch_pipelines
- runner: docker-cpu
+ runner: [ self-hosted, intel-cpu, 32-cpu, 256-ram, ci ]
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_pipelines
- name: Fast PyTorch Models & Schedulers CPU tests
framework: pytorch_models
- runner: docker-cpu
+ runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_models_schedulers
- name: Fast Flax CPU tests
framework: flax
- runner: docker-cpu
+ runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
image: diffusers/diffusers-flax-cpu
report: flax_cpu
- name: PyTorch Example CPU tests
framework: pytorch_examples
- runner: docker-cpu
+ runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
image: diffusers/diffusers-pytorch-cpu
report: torch_example_cpu
@@ -118,7 +116,6 @@ jobs:
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install accelerate
@@ -132,7 +129,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
+ python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/pipelines
@@ -141,7 +138,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_models' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and not Dependency" \
--make-reports=tests_${{ matrix.config.report }} \
tests/models tests/schedulers tests/others
@@ -150,7 +147,7 @@ jobs:
if: ${{ matrix.config.framework == 'flax' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Flax" \
--make-reports=tests_${{ matrix.config.report }} \
tests
@@ -160,7 +157,7 @@ jobs:
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install peft
- python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
examples
@@ -183,7 +180,7 @@ jobs:
config:
- name: Hub tests for models, schedulers, and pipelines
framework: hub_tests_pytorch
- runner: docker-cpu
+ runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
image: diffusers/diffusers-pytorch-cpu
report: torch_hub
@@ -207,7 +204,6 @@ jobs:
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml
index 0a316c90dfed..a6cb123a7035 100644
--- a/.github/workflows/push_tests.yml
+++ b/.github/workflows/push_tests.yml
@@ -60,7 +60,7 @@ jobs:
runs-on: [single-gpu, nvidia-gpu, t4, ci]
container:
image: diffusers/diffusers-pytorch-cuda
- options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0 --privileged
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
@@ -69,9 +69,14 @@ jobs:
- name: NVIDIA-SMI
run: |
nvidia-smi
+ - name: Tailscale
+ uses: huggingface/tailscale-action@v1
+ with:
+ authkey: ${{ secrets.TAILSCALE_SSH_AUTHKEY }}
+ slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}
+ slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
@@ -88,6 +93,12 @@ jobs:
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
+ - name: Tailscale Wait
+ if: ${{ failure() || runner.debug == '1' }}
+ uses: huggingface/tailscale-action@v1
+ with:
+ waitForSSH: true
+ authkey: ${{ secrets.TAILSCALE_SSH_AUTHKEY }}
- name: Failure short reports
if: ${{ failure() }}
run: |
@@ -121,7 +132,6 @@ jobs:
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
@@ -171,11 +181,10 @@ jobs:
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
- python -m uv pip install peft@git+https://github.com/huggingface/peft.git
+ python -m pip install -U peft@git+https://github.com/huggingface/peft.git
- name: Environment
run: |
@@ -222,7 +231,6 @@ jobs:
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
@@ -270,7 +278,6 @@ jobs:
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
@@ -430,4 +437,4 @@ jobs:
uses: actions/upload-artifact@v2
with:
name: examples_test_reports
- path: reports
\ No newline at end of file
+ path: reports
diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml
index 6b01577041b2..7c50da7b5c34 100644
--- a/.github/workflows/push_tests_fast.yml
+++ b/.github/workflows/push_tests_fast.yml
@@ -29,22 +29,22 @@ jobs:
config:
- name: Fast PyTorch CPU tests on Ubuntu
framework: pytorch
- runner: docker-cpu
+ runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu
- name: Fast Flax CPU tests on Ubuntu
framework: flax
- runner: docker-cpu
+ runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
image: diffusers/diffusers-flax-cpu
report: flax_cpu
- name: Fast ONNXRuntime CPU tests on Ubuntu
framework: onnxruntime
- runner: docker-cpu
+ runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
image: diffusers/diffusers-onnxruntime-cpu
report: onnx_cpu
- name: PyTorch Example CPU tests on Ubuntu
framework: pytorch_examples
- runner: docker-cpu
+ runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
image: diffusers/diffusers-pytorch-cpu
report: torch_example_cpu
@@ -68,7 +68,6 @@ jobs:
- name: Install dependencies
run: |
- apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
@@ -81,7 +80,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
@@ -90,7 +89,7 @@ jobs:
if: ${{ matrix.config.framework == 'flax' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Flax" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
@@ -99,7 +98,7 @@ jobs:
if: ${{ matrix.config.framework == 'onnxruntime' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
@@ -109,7 +108,7 @@ jobs:
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install peft
- python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
+ python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
examples
diff --git a/.github/workflows/update_metadata.yml b/.github/workflows/update_metadata.yml
new file mode 100644
index 000000000000..f91fa29a1ab9
--- /dev/null
+++ b/.github/workflows/update_metadata.yml
@@ -0,0 +1,30 @@
+name: Update Diffusers metadata
+
+on:
+ workflow_dispatch:
+ push:
+ branches:
+ - main
+ - update_diffusers_metadata*
+
+jobs:
+ update_metadata:
+ runs-on: ubuntu-22.04
+ defaults:
+ run:
+ shell: bash -l {0}
+
+ steps:
+ - uses: actions/checkout@v3
+
+ - name: Setup environment
+ run: |
+ pip install --upgrade pip
+ pip install datasets pandas
+ pip install .[torch]
+
+ - name: Update metadata
+ env:
+ HUGGING_FACE_HUB_TOKEN: ${{ secrets.SAYAK_HF_TOKEN }}
+ run: |
+ python utils/update_metadata.py --commit_sha ${{ github.sha }}
diff --git a/Makefile b/Makefile
index c92285b48c71..9af2e8b1a5c9 100644
--- a/Makefile
+++ b/Makefile
@@ -42,6 +42,7 @@ repo-consistency:
quality:
ruff check $(check_dirs) setup.py
ruff format --check $(check_dirs) setup.py
+ doc-builder style src/diffusers docs/source --max_len 119 --check_only
python utils/check_doc_toc.py
# Format source code automatically and check is there are any problems left that need manual fixing
@@ -55,6 +56,7 @@ extra_style_checks:
style:
ruff check $(check_dirs) setup.py --fix
ruff format $(check_dirs) setup.py
+ doc-builder style src/diffusers docs/source --max_len 119
${MAKE} autogenerate_code
${MAKE} extra_style_checks
diff --git a/docker/diffusers-flax-cpu/Dockerfile b/docker/diffusers-flax-cpu/Dockerfile
index 36d036e34e5f..005c0f9caacf 100644
--- a/docker/diffusers-flax-cpu/Dockerfile
+++ b/docker/diffusers-flax-cpu/Dockerfile
@@ -12,6 +12,7 @@ RUN apt update && \
curl \
ca-certificates \
libsndfile1-dev \
+ libgl1 \
python3.8 \
python3-pip \
python3.8-venv && \
diff --git a/docker/diffusers-flax-tpu/Dockerfile b/docker/diffusers-flax-tpu/Dockerfile
index 78d5f972a753..05ea22488ab9 100644
--- a/docker/diffusers-flax-tpu/Dockerfile
+++ b/docker/diffusers-flax-tpu/Dockerfile
@@ -12,6 +12,7 @@ RUN apt update && \
curl \
ca-certificates \
libsndfile1-dev \
+ libgl1 \
python3.8 \
python3-pip \
python3.8-venv && \
diff --git a/docker/diffusers-onnxruntime-cpu/Dockerfile b/docker/diffusers-onnxruntime-cpu/Dockerfile
index 0d032d91e5eb..b60b467b7485 100644
--- a/docker/diffusers-onnxruntime-cpu/Dockerfile
+++ b/docker/diffusers-onnxruntime-cpu/Dockerfile
@@ -12,6 +12,7 @@ RUN apt update && \
curl \
ca-certificates \
libsndfile1-dev \
+ libgl1 \
python3.8 \
python3-pip \
python3.8-venv && \
diff --git a/docker/diffusers-onnxruntime-cuda/Dockerfile b/docker/diffusers-onnxruntime-cuda/Dockerfile
index 34e611df7257..16a0d76460f4 100644
--- a/docker/diffusers-onnxruntime-cuda/Dockerfile
+++ b/docker/diffusers-onnxruntime-cuda/Dockerfile
@@ -12,6 +12,7 @@ RUN apt update && \
curl \
ca-certificates \
libsndfile1-dev \
+ libgl1 \
python3.8 \
python3-pip \
python3.8-venv && \
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 29bd65fb4dba..ea5d1a021c94 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -24,14 +24,12 @@
title: Tutorials
- sections:
- sections:
- - local: using-diffusers/loading_overview
- title: Overview
- local: using-diffusers/loading
- title: Load pipelines, models, and schedulers
- - local: using-diffusers/schedulers
- title: Load and compare different schedulers
+ title: Load pipelines
- local: using-diffusers/custom_pipeline_overview
title: Load community pipelines and components
+ - local: using-diffusers/schedulers
+ title: Load schedulers and models
- local: using-diffusers/using_safetensors
title: Load safetensors
- local: using-diffusers/other-formats
@@ -71,7 +69,7 @@
- local: using-diffusers/control_brightness
title: Control image brightness
- local: using-diffusers/weighted_prompts
- title: Prompt weighting
+ title: Prompt techniques
- local: using-diffusers/freeu
title: Improve generation quality with FreeU
title: Techniques
@@ -86,6 +84,8 @@
title: Kandinsky
- local: using-diffusers/controlnet
title: ControlNet
+ - local: using-diffusers/t2i_adapter
+ title: T2I-Adapter
- local: using-diffusers/shap-e
title: Shap-E
- local: using-diffusers/diffedit
@@ -170,6 +170,8 @@
title: Token merging
- local: optimization/deepcache
title: DeepCache
+ - local: optimization/tgate
+ title: TGATE
title: General optimizations
- sections:
- local: using-diffusers/stable_diffusion_jax_how_to
@@ -280,6 +282,10 @@
title: ControlNet
- local: api/pipelines/controlnet_sdxl
title: ControlNet with Stable Diffusion XL
+ - local: api/pipelines/controlnetxs
+ title: ControlNet-XS
+ - local: api/pipelines/controlnetxs_sdxl
+ title: ControlNet-XS with Stable Diffusion XL
- local: api/pipelines/dance_diffusion
title: Dance Diffusion
- local: api/pipelines/ddim
@@ -358,7 +364,7 @@
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
- local: api/pipelines/stable_diffusion/adapter
- title: Stable Diffusion T2I-Adapter
+ title: T2I-Adapter
- local: api/pipelines/stable_diffusion/gligen
title: GLIGEN (Grounded Language-to-Image Generation)
title: Stable Diffusion
diff --git a/docs/source/en/api/pipelines/audioldm2.md b/docs/source/en/api/pipelines/audioldm2.md
index b29bea96e324..ac4459c60706 100644
--- a/docs/source/en/api/pipelines/audioldm2.md
+++ b/docs/source/en/api/pipelines/audioldm2.md
@@ -20,7 +20,8 @@ The abstract of the paper is the following:
*Although audio generation shares commonalities across different types of audio, such as speech, music, and sound effects, designing models for each type requires careful consideration of specific objectives and biases that can significantly differ from those of other types. To bring us closer to a unified perspective of audio generation, this paper proposes a framework that utilizes the same learning method for speech, music, and sound effect generation. Our framework introduces a general representation of audio, called "language of audio" (LOA). Any audio can be translated into LOA based on AudioMAE, a self-supervised pre-trained representation learning model. In the generation process, we translate any modalities into LOA by using a GPT-2 model, and we perform self-supervised audio generation learning with a latent diffusion model conditioned on LOA. The proposed framework naturally brings advantages such as in-context learning abilities and reusable self-supervised pretrained AudioMAE and latent diffusion models. Experiments on the major benchmarks of text-to-audio, text-to-music, and text-to-speech demonstrate state-of-the-art or competitive performance against previous approaches. Our code, pretrained model, and demo are available at [this https URL](https://audioldm.github.io/audioldm2).*
-This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi). The original codebase can be found at [haoheliu/audioldm2](https://github.com/haoheliu/audioldm2).
+This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi) and [Nguyễn Công Tú Anh](https://github.com/tuanh123789). The original codebase can be
+found at [haoheliu/audioldm2](https://github.com/haoheliu/audioldm2).
## Tips
@@ -36,6 +37,8 @@ See table below for details on the three checkpoints:
| [audioldm2](https://huggingface.co/cvssp/audioldm2) | Text-to-audio | 350M | 1.1B | 1150k |
| [audioldm2-large](https://huggingface.co/cvssp/audioldm2-large) | Text-to-audio | 750M | 1.5B | 1150k |
| [audioldm2-music](https://huggingface.co/cvssp/audioldm2-music) | Text-to-music | 350M | 1.1B | 665k |
+| [audioldm2-gigaspeech](https://huggingface.co/anhnct/audioldm2_gigaspeech) | Text-to-speech | 350M | 1.1B | 10k |
+| [audioldm2-ljspeech](https://huggingface.co/anhnct/audioldm2_ljspeech) | Text-to-speech | 350M | 1.1B | |
### Constructing a prompt
@@ -53,7 +56,7 @@ See table below for details on the three checkpoints:
* The quality of the generated waveforms can vary significantly based on the seed. Try generating with different seeds until you find a satisfactory generation.
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.
-The following example demonstrates how to construct good music generation using the aforementioned tips: [example](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2Pipeline.__call__.example).
+The following example demonstrates how to construct good music and speech generation using the aforementioned tips: [example](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2Pipeline.__call__.example).
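To make the tips above concrete, here is a minimal text-to-audio sketch using the base `audioldm2` checkpoint from the table; the prompt, seed, and step count are illustrative values rather than tuned recommendations:

```py
import scipy.io.wavfile
import torch
from diffusers import AudioLDM2Pipeline

pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2", torch_dtype=torch.float16).to("cuda")

prompt = "Techno music with a strong, upbeat tempo and high melodic riffs."
negative_prompt = "Low quality."
generator = torch.Generator("cuda").manual_seed(0)

# generate several candidate waveforms; they are automatically ranked from best to worst
audio = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=200,
    audio_length_in_s=10.0,
    num_waveforms_per_prompt=3,
    generator=generator,
).audios[0]  # keep the best-scoring waveform

# AudioLDM2 generates audio at a 16 kHz sampling rate
scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
```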
diff --git a/examples/research_projects/controlnetxs/README.md b/docs/source/en/api/pipelines/controlnetxs.md
similarity index 61%
rename from examples/research_projects/controlnetxs/README.md
rename to docs/source/en/api/pipelines/controlnetxs.md
index 72ed91c01db2..2d4ae7b8ce46 100644
--- a/examples/research_projects/controlnetxs/README.md
+++ b/docs/source/en/api/pipelines/controlnetxs.md
@@ -1,3 +1,15 @@
+
+
# ControlNet-XS
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
@@ -12,5 +24,16 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## StableDiffusionControlNetXSPipeline
+[[autodoc]] StableDiffusionControlNetXSPipeline
+ - all
+ - __call__
-> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
\ No newline at end of file
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/examples/research_projects/controlnetxs/README_sdxl.md b/docs/source/en/api/pipelines/controlnetxs_sdxl.md
similarity index 56%
rename from examples/research_projects/controlnetxs/README_sdxl.md
rename to docs/source/en/api/pipelines/controlnetxs_sdxl.md
index d401c1e76698..31075c0ef96a 100644
--- a/examples/research_projects/controlnetxs/README_sdxl.md
+++ b/docs/source/en/api/pipelines/controlnetxs_sdxl.md
@@ -1,3 +1,15 @@
+
+
# ControlNet-XS with Stable Diffusion XL
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
@@ -12,4 +24,22 @@ Here's the overview from the [project page](https://vislearn.github.io/ControlNe
This model was contributed by [UmerHA](https://twitter.com/UmerHAdil). ❤️
-> 🧠 Make sure to check out the Schedulers [guide](https://huggingface.co/docs/diffusers/main/en/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
\ No newline at end of file
+
+
+🧪 Many of the SDXL ControlNet checkpoints are experimental, and there is a lot of room for improvement. Feel free to open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) and leave us feedback on how we can improve!
+
+
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## StableDiffusionXLControlNetXSPipeline
+[[autodoc]] StableDiffusionXLControlNetXSPipeline
+ - all
+ - __call__
+
+## StableDiffusionPipelineOutput
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
diff --git a/docs/source/en/api/pipelines/stable_diffusion/adapter.md b/docs/source/en/api/pipelines/stable_diffusion/adapter.md
index aa38e3d9741f..ca42fdc83984 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/adapter.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/adapter.md
@@ -10,9 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Text-to-Image Generation with Adapter Conditioning
-
-## Overview
+# T2I-Adapter
[T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.08453) by Chong Mou, Xintao Wang, Liangbin Xie, Jian Zhang, Zhongang Qi, Ying Shan, Xiaohu Qie.
@@ -24,236 +22,26 @@ The abstract of the paper is the following:
This model was contributed by the community contributor [HimariO](https://github.com/HimariO) ❤️ .
-## Available Pipelines:
-
-| Pipeline | Tasks | Demo
-|---|---|:---:|
-| [StableDiffusionAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning* | -
-| [StableDiffusionXLAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning on StableDiffusion-XL* | -
-
-## Usage example with the base model of StableDiffusion-1.4/1.5
-
-In the following we give a simple example of how to use a *T2I-Adapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5.
-All adapters use the same pipeline.
-
- 1. Images are first converted into the appropriate *control image* format.
- 2. The *control image* and *prompt* are passed to the [`StableDiffusionAdapterPipeline`].
-
-Let's have a look at a simple example using the [Color Adapter](https://huggingface.co/TencentARC/t2iadapter_color_sd14v1).
-
-```python
-from diffusers.utils import load_image, make_image_grid
-
-image = load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png")
-```
-
-![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_ref.png)
-
-
-Then we can create our color palette by simply resizing it to 8 by 8 pixels and then scaling it back to original size.
-
-```python
-from PIL import Image
-
-color_palette = image.resize((8, 8))
-color_palette = color_palette.resize((512, 512), resample=Image.Resampling.NEAREST)
-```
-
-Let's take a look at the processed image.
-
-![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_palette.png)
-
-
-Next, create the adapter pipeline
-
-```py
-import torch
-from diffusers import StableDiffusionAdapterPipeline, T2IAdapter
-
-adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1", torch_dtype=torch.float16)
-pipe = StableDiffusionAdapterPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4",
- adapter=adapter,
- torch_dtype=torch.float16,
-)
-pipe.to("cuda")
-```
-
-Finally, pass the prompt and control image to the pipeline
-
-```py
-# fix the random seed, so you will get the same result as the example
-generator = torch.Generator("cuda").manual_seed(7)
-
-out_image = pipe(
- "At night, glowing cubes in front of the beach",
- image=color_palette,
- generator=generator,
-).images[0]
-make_image_grid([image, color_palette, out_image], rows=1, cols=3)
-```
-
-![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/color_output.png)
-
-## Usage example with the base model of StableDiffusion-XL
-
-In the following we give a simple example of how to use a *T2I-Adapter* checkpoint with Diffusers for inference based on StableDiffusion-XL.
-All adapters use the same pipeline.
-
- 1. Images are first downloaded into the appropriate *control image* format.
- 2. The *control image* and *prompt* are passed to the [`StableDiffusionXLAdapterPipeline`].
-
-Let's have a look at a simple example using the [Sketch Adapter](https://huggingface.co/Adapter/t2iadapter/tree/main/sketch_sdxl_1.0).
-
-```python
-from diffusers.utils import load_image, make_image_grid
-
-sketch_image = load_image("https://huggingface.co/Adapter/t2iadapter/resolve/main/sketch.png").convert("L")
-```
-
-![img](https://huggingface.co/Adapter/t2iadapter/resolve/main/sketch.png)
-
-Then, create the adapter pipeline
-
-```py
-import torch
-from diffusers import (
- T2IAdapter,
- StableDiffusionXLAdapterPipeline,
- DDPMScheduler
-)
-
-model_id = "stabilityai/stable-diffusion-xl-base-1.0"
-adapter = T2IAdapter.from_pretrained("Adapter/t2iadapter", subfolder="sketch_sdxl_1.0", torch_dtype=torch.float16, adapter_type="full_adapter_xl")
-scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
-
-pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
- model_id, adapter=adapter, safety_checker=None, torch_dtype=torch.float16, variant="fp16", scheduler=scheduler
-)
-
-pipe.to("cuda")
-```
-
-Finally, pass the prompt and control image to the pipeline
-
-```py
-# fix the random seed, so you will get the same result as the example
-generator = torch.Generator().manual_seed(42)
-
-sketch_image_out = pipe(
- prompt="a photo of a dog in real world, high quality",
- negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
- image=sketch_image,
- generator=generator,
- guidance_scale=7.5
-).images[0]
-make_image_grid([sketch_image, sketch_image_out], rows=1, cols=2)
-```
-
-![img](https://huggingface.co/Adapter/t2iadapter/resolve/main/sketch_output.png)
-
-## Available checkpoints
-
-Non-diffusers checkpoints can be found under [TencentARC/T2I-Adapter](https://huggingface.co/TencentARC/T2I-Adapter/tree/main/models).
-
-### T2I-Adapter with Stable Diffusion 1.4
-
-| Model Name | Control Image Overview| Control Image Example | Generated Image Example |
-|---|---|---|---|
-|[TencentARC/t2iadapter_color_sd14v1](https://huggingface.co/TencentARC/t2iadapter_color_sd14v1)<br/>*Trained with spatial color palette* | An image with 8x8 color palette.|||
-|[TencentARC/t2iadapter_canny_sd14v1](https://huggingface.co/TencentARC/t2iadapter_canny_sd14v1)<br/>*Trained with canny edge detection* | A monochrome image with white edges on a black background.|||
-|[TencentARC/t2iadapter_sketch_sd14v1](https://huggingface.co/TencentARC/t2iadapter_sketch_sd14v1)<br/>*Trained with [PidiNet](https://github.com/zhuoinoulu/pidinet) edge detection* | A hand-drawn monochrome image with white outlines on a black background.|||
-|[TencentARC/t2iadapter_depth_sd14v1](https://huggingface.co/TencentARC/t2iadapter_depth_sd14v1)<br/>*Trained with Midas depth estimation* | A grayscale image with black representing deep areas and white representing shallow areas.|||
-|[TencentARC/t2iadapter_openpose_sd14v1](https://huggingface.co/TencentARC/t2iadapter_openpose_sd14v1)<br/>*Trained with OpenPose bone image* | A [OpenPose bone](https://github.com/CMU-Perceptual-Computing-Lab/openpose) image.|||
-|[TencentARC/t2iadapter_keypose_sd14v1](https://huggingface.co/TencentARC/t2iadapter_keypose_sd14v1)<br/>*Trained with mmpose skeleton image* | A [mmpose skeleton](https://github.com/open-mmlab/mmpose) image.|||
-|[TencentARC/t2iadapter_seg_sd14v1](https://huggingface.co/TencentARC/t2iadapter_seg_sd14v1)<br/>*Trained with semantic segmentation* | An [custom](https://github.com/TencentARC/T2I-Adapter/discussions/25) segmentation protocol image.|| |
-|[TencentARC/t2iadapter_canny_sd15v2](https://huggingface.co/TencentARC/t2iadapter_canny_sd15v2)||
-|[TencentARC/t2iadapter_depth_sd15v2](https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2)||
-|[TencentARC/t2iadapter_sketch_sd15v2](https://huggingface.co/TencentARC/t2iadapter_sketch_sd15v2)||
-|[TencentARC/t2iadapter_zoedepth_sd15v1](https://huggingface.co/TencentARC/t2iadapter_zoedepth_sd15v1)||
-|[Adapter/t2iadapter, subfolder='sketch_sdxl_1.0'](https://huggingface.co/Adapter/t2iadapter/tree/main/sketch_sdxl_1.0)||
-|[Adapter/t2iadapter, subfolder='canny_sdxl_1.0'](https://huggingface.co/Adapter/t2iadapter/tree/main/canny_sdxl_1.0)||
-|[Adapter/t2iadapter, subfolder='openpose_sdxl_1.0'](https://huggingface.co/Adapter/t2iadapter/tree/main/openpose_sdxl_1.0)||
-
-## Combining multiple adapters
-
-[`MultiAdapter`] can be used for applying multiple conditionings at once.
-
-Here we use the keypose adapter for the character posture and the depth adapter for creating the scene.
-
-```py
-from diffusers.utils import load_image, make_image_grid
-
-cond_keypose = load_image(
- "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png"
-)
-cond_depth = load_image(
- "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png"
-)
-cond = [cond_keypose, cond_depth]
-
-prompt = ["A man walking in an office room with a nice view"]
-```
-
-The two control images look as such:
-
-![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_sample_input.png)
-![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/depth_sample_input.png)
-
-
-`MultiAdapter` combines keypose and depth adapters.
-
-`adapter_conditioning_scale` balances the relative influence of the different adapters.
-
-```py
-import torch
-from diffusers import StableDiffusionAdapterPipeline, MultiAdapter, T2IAdapter
-
-adapters = MultiAdapter(
- [
- T2IAdapter.from_pretrained("TencentARC/t2iadapter_keypose_sd14v1"),
- T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd14v1"),
- ]
-)
-adapters = adapters.to(torch.float16)
-
-pipe = StableDiffusionAdapterPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4",
- torch_dtype=torch.float16,
- adapter=adapters,
-).to("cuda")
-
-image = pipe(prompt, cond, adapter_conditioning_scale=[0.8, 0.8]).images[0]
-make_image_grid([cond_keypose, cond_depth, image], rows=1, cols=3)
-```
-
-![img](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/t2i-adapter/keypose_depth_sample_output.png)
-
-
-## T2I-Adapter vs ControlNet
-
-T2I-Adapter is similar to [ControlNet](https://huggingface.co/docs/diffusers/main/en/api/pipelines/controlnet).
-T2I-Adapter uses a smaller auxiliary network which is only run once for the entire diffusion process.
-However, T2I-Adapter performs slightly worse than ControlNet.
-
## StableDiffusionAdapterPipeline
+
[[autodoc]] StableDiffusionAdapterPipeline
- - all
- - __call__
- - enable_attention_slicing
- - disable_attention_slicing
- - enable_vae_slicing
- - disable_vae_slicing
- - enable_xformers_memory_efficient_attention
- - disable_xformers_memory_efficient_attention
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
## StableDiffusionXLAdapterPipeline
+
[[autodoc]] StableDiffusionXLAdapterPipeline
- - all
- - __call__
- - enable_attention_slicing
- - disable_attention_slicing
- - enable_vae_slicing
- - disable_vae_slicing
- - enable_xformers_memory_efficient_attention
- - disable_xformers_memory_efficient_attention
+ - all
+ - __call__
+ - enable_attention_slicing
+ - disable_attention_slicing
+ - enable_vae_slicing
+ - disable_vae_slicing
+ - enable_xformers_memory_efficient_attention
+ - disable_xformers_memory_efficient_attention
diff --git a/docs/source/en/optimization/tgate.md b/docs/source/en/optimization/tgate.md
new file mode 100644
index 000000000000..d208ddfa8411
--- /dev/null
+++ b/docs/source/en/optimization/tgate.md
@@ -0,0 +1,182 @@
+# T-GATE
+
+[T-GATE](https://github.com/HaozheLiu-ST/T-GATE/tree/main) accelerates inference for [Stable Diffusion](../api/pipelines/stable_diffusion/overview), [PixArt](../api/pipelines/pixart), and [Latent Consistency Model](../api/pipelines/latent_consistency_models) pipelines by skipping the cross-attention calculation once it converges. This method doesn't require any additional training and it can speed up inference by 10-50%. T-GATE is also compatible with other optimization methods like [DeepCache](./deepcache).
+
+Before you begin, make sure you install T-GATE.
+
+```bash
+pip install tgate
+pip install -U torch diffusers transformers accelerate DeepCache
+```
+
+
+To use T-GATE with a pipeline, you need to use its corresponding loader.
+
+| Pipeline | T-GATE Loader |
+|---|---|
+| PixArt | TgatePixArtLoader |
+| Stable Diffusion XL | TgateSDXLLoader |
+| Stable Diffusion XL + DeepCache | TgateSDXLDeepCacheLoader |
+| Stable Diffusion | TgateSDLoader |
+| Stable Diffusion + DeepCache | TgateSDDeepCacheLoader |
+
+Next, create the corresponding loader with the pipeline, the gate step (the time step at which to stop calculating the cross-attention), and the number of inference steps. Then call the `tgate` method on the pipeline with a prompt, gate step, and the number of inference steps.
+
+Let's see how to enable this for several different pipelines.
+
+
+
+
+Accelerate `PixArtAlphaPipeline` with T-GATE:
+
+```py
+import torch
+from diffusers import PixArtAlphaPipeline
+from tgate import TgatePixArtLoader
+
+pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
+
+gate_step = 8
+inference_step = 25
+pipe = TgatePixArtLoader(
+ pipe,
+ gate_step=gate_step,
+ num_inference_steps=inference_step,
+).to("cuda")
+
+image = pipe.tgate(
+ "An alpaca made of colorful building blocks, cyberpunk.",
+ gate_step=gate_step,
+ num_inference_steps=inference_step,
+).images[0]
+```
+
+
+
+Accelerate `StableDiffusionXLPipeline` with T-GATE:
+
+```py
+import torch
+from diffusers import StableDiffusionXLPipeline
+from diffusers import DPMSolverMultistepScheduler
+from tgate import TgateSDXLLoader
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ variant="fp16",
+ use_safetensors=True,
+)
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+
+gate_step = 10
+inference_step = 25
+pipe = TgateSDXLLoader(
+ pipe,
+ gate_step=gate_step,
+ num_inference_steps=inference_step,
+).to("cuda")
+
+image = pipe.tgate(
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
+ gate_step=gate_step,
+ num_inference_steps=inference_step
+).images[0]
+```
+
+
+
+Accelerate `StableDiffusionXLPipeline` with [DeepCache](https://github.com/horseee/DeepCache) and T-GATE:
+
+```py
+import torch
+from diffusers import StableDiffusionXLPipeline
+from diffusers import DPMSolverMultistepScheduler
+from tgate import TgateSDXLDeepCacheLoader
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+ variant="fp16",
+ use_safetensors=True,
+)
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
+
+gate_step = 10
+inference_step = 25
+pipe = TgateSDXLDeepCacheLoader(
+ pipe,
+ cache_interval=3,
+ cache_branch_id=0,
+).to("cuda")
+
+image = pipe.tgate(
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
+ gate_step=gate_step,
+ num_inference_steps=inference_step
+).images[0]
+```
+
+
+
+Accelerate `latent-consistency/lcm-sdxl` with T-GATE:
+
+```py
+import torch
+from diffusers import StableDiffusionXLPipeline
+from diffusers import UNet2DConditionModel, LCMScheduler
+from diffusers import DPMSolverMultistepScheduler
+from tgate import TgateSDXLLoader
+
+unet = UNet2DConditionModel.from_pretrained(
+ "latent-consistency/lcm-sdxl",
+ torch_dtype=torch.float16,
+ variant="fp16",
+)
+pipe = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ unet=unet,
+ torch_dtype=torch.float16,
+ variant="fp16",
+)
+pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
+
+gate_step = 1
+inference_step = 4
+pipe = TgateSDXLLoader(
+ pipe,
+ gate_step=gate_step,
+ num_inference_steps=inference_step,
+ lcm=True
+).to("cuda")
+
+image = pipe.tgate(
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
+ gate_step=gate_step,
+ num_inference_steps=inference_step
+).images[0]
+```
+
+
+
+T-GATE also supports [`StableDiffusionPipeline`] and [PixArt-alpha/PixArt-LCM-XL-2-1024-MS](https://hf.co/PixArt-alpha/PixArt-LCM-XL-2-1024-MS).
+
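For the plain [`StableDiffusionPipeline`], a minimal sketch would follow the same pattern as the SDXL example above, assuming `TgateSDLoader` accepts the same arguments as `TgateSDXLLoader`:

```py
import torch
from diffusers import StableDiffusionPipeline
from tgate import TgateSDLoader

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)

gate_step = 10
inference_step = 25
pipe = TgateSDLoader(
    pipe,
    gate_step=gate_step,
    num_inference_steps=inference_step,
).to("cuda")

image = pipe.tgate(
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k.",
    gate_step=gate_step,
    num_inference_steps=inference_step,
).images[0]
```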
+## Benchmarks
+| Model | MACs | Param | Latency | Zero-shot 10K-FID on MS-COCO |
+|-----------------------|----------|-----------|---------|---------------------------|
+| SD-1.5 | 16.938T | 859.520M | 7.032s | 23.927 |
+| SD-1.5 w/ T-GATE | 9.875T | 815.557M | 4.313s | 20.789 |
+| SD-2.1 | 38.041T | 865.785M | 16.121s | 22.609 |
+| SD-2.1 w/ T-GATE | 22.208T | 815.433M | 9.878s | 19.940 |
+| SD-XL | 149.438T | 2.570B | 53.187s | 24.628 |
+| SD-XL w/ T-GATE | 84.438T | 2.024B | 27.932s | 22.738 |
+| Pixart-Alpha | 107.031T | 611.350M | 61.502s | 38.669 |
+| Pixart-Alpha w/ T-GATE | 65.318T | 462.585M | 37.867s | 35.825 |
+| DeepCache (SD-XL) | 57.888T | - | 19.931s | 23.755 |
+| DeepCache w/ T-GATE | 43.868T | - | 14.666s | 23.999 |
+| LCM (SD-XL) | 11.955T | 2.570B | 3.805s | 25.044 |
+| LCM w/ T-GATE | 11.171T | 2.024B | 3.533s | 25.028 |
+| LCM (Pixart-Alpha) | 8.563T | 611.350M | 4.733s | 36.086 |
+| LCM w/ T-GATE | 7.623T | 462.585M | 4.543s | 37.048 |
+
+The latency is tested on an NVIDIA 1080TI. MACs and Params are calculated with [calflops](https://github.com/MrYxJ/calculate-flops.pytorch), and the FID is calculated with [PytorchFID](https://github.com/mseitzer/pytorch-fid).
diff --git a/docs/source/en/training/distributed_inference.md b/docs/source/en/training/distributed_inference.md
index 008dc3002bb5..40876a26e6a3 100644
--- a/docs/source/en/training/distributed_inference.md
+++ b/docs/source/en/training/distributed_inference.md
@@ -52,6 +52,76 @@ To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](h
+### Device placement
+
+> [!WARNING]
+> This feature is experimental and its APIs might change in the future.
+
+With Accelerate, you can use the `device_map` to determine how to distribute the models of a pipeline across multiple devices. This is useful in situations where you have more than one GPU.
+
+For example, if you have two 8GB GPUs, then using [`~DiffusionPipeline.enable_model_cpu_offload`] may not work so well because:
+
+* it only works on a single GPU
+* a single model might not fit on a single GPU ([`~DiffusionPipeline.enable_sequential_cpu_offload`] might work but it will be extremely slow and it is also limited to a single GPU)
+
+To make use of both GPUs, you can use the "balanced" device placement strategy which splits the models across all available GPUs.
+
+> [!WARNING]
+> Only the "balanced" strategy is supported at the moment, and we plan to support additional mapping strategies in the future.
+
+```diff
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(
+- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True,
++ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True, device_map="balanced"
+)
+image = pipeline("a dog").images[0]
+image
+```
+
+You can also pass a dictionary to enforce the maximum GPU memory that can be used on each device:
+
+```diff
+from diffusers import DiffusionPipeline
+import torch
+
+max_memory = {0:"1GB", 1:"1GB"}
+pipeline = DiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+ use_safetensors=True,
+ device_map="balanced",
++ max_memory=max_memory
+)
+image = pipeline("a dog").images[0]
+image
+```
+
+If a device is not present in `max_memory`, then it will be completely ignored and will not participate in the device placement.
+
+By default, Diffusers uses the maximum memory of all devices. If the models don't fit on the GPUs, they are offloaded to the CPU. If the CPU doesn't have enough memory, then you might see an error. In that case, you can fall back to [`~DiffusionPipeline.enable_sequential_cpu_offload`] or [`~DiffusionPipeline.enable_model_cpu_offload`].
+
+Call [`~DiffusionPipeline.reset_device_map`] to reset the `device_map` of a pipeline. This is also necessary if you want to use methods like `to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`] on a pipeline that was device-mapped.
+
+```py
+pipeline.reset_device_map()
+```
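For instance, once the device map is reset, the single-device methods mentioned above can be used again; a minimal sketch (assuming `pipeline` is the device-mapped pipeline from the earlier snippets):

```py
# after resetting, the pipeline behaves like a regular single-device pipeline again
pipeline.reset_device_map()

# either place everything on one GPU...
pipeline.to("cuda")

# ...or use model-level CPU offloading instead
# pipeline.enable_model_cpu_offload()
```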
+
+Once a pipeline has been device-mapped, you can also access its device map via `hf_device_map`:
+
+```py
+print(pipeline.hf_device_map)
+```
+
+An example device map would look like so:
+
+
+```bash
+{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
+```
+
## PyTorch Distributed
PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) which enables data parallelism.
diff --git a/docs/source/en/using-diffusers/callback.md b/docs/source/en/using-diffusers/callback.md
index 296245c3abe2..3f3e8dae9f2d 100644
--- a/docs/source/en/using-diffusers/callback.md
+++ b/docs/source/en/using-diffusers/callback.md
@@ -148,9 +148,9 @@ pipeline = AutoPipelineForText2Image.from_pretrained(
use_safetensors=True
).to("cuda")
-image = pipe(
- prompt = "A croissant shaped like a cute bear."
- negative_prompt = "Deformed, ugly, bad anatomy"
+image = pipeline(
+ prompt="A croissant shaped like a cute bear.",
+ negative_prompt="Deformed, ugly, bad anatomy",
callback_on_step_end=decode_tensors,
callback_on_step_end_tensor_inputs=["latents"],
).images[0]
diff --git a/docs/source/en/using-diffusers/custom_pipeline_overview.md b/docs/source/en/using-diffusers/custom_pipeline_overview.md
index 3c03ddc732f1..0b6bb53f10d6 100644
--- a/docs/source/en/using-diffusers/custom_pipeline_overview.md
+++ b/docs/source/en/using-diffusers/custom_pipeline_overview.md
@@ -16,17 +16,19 @@ specific language governing permissions and limitations under the License.
## Community pipelines
-Community pipelines are any [`DiffusionPipeline`] class that are different from the original implementation as specified in their paper (for example, the [`StableDiffusionControlNetPipeline`] corresponds to the [Text-to-Image Generation with ControlNet Conditioning](https://arxiv.org/abs/2302.05543) paper). They provide additional functionality or extend the original implementation of a pipeline.
+Community pipelines are any [`DiffusionPipeline`] class that is different from the original paper implementation (for example, the [`StableDiffusionControlNetPipeline`] corresponds to the [Text-to-Image Generation with ControlNet Conditioning](https://arxiv.org/abs/2302.05543) paper). They provide additional functionality or extend the original implementation of a pipeline.
-There are many cool community pipelines like [Speech to Image](https://github.com/huggingface/diffusers/tree/main/examples/community#speech-to-image) or [Composable Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#composable-stable-diffusion), and you can find all the official community pipelines [here](https://github.com/huggingface/diffusers/tree/main/examples/community).
+There are many cool community pipelines like [Marigold Depth Estimation](https://github.com/huggingface/diffusers/tree/main/examples/community#marigold-depth-estimation) or [InstantID](https://github.com/huggingface/diffusers/tree/main/examples/community#instantid-pipeline), and you can find all the official community pipelines [here](https://github.com/huggingface/diffusers/tree/main/examples/community).
-To load any community pipeline on the Hub, pass the repository id of the community pipeline to the `custom_pipeline` argument and the model repository where you'd like to load the pipeline weights and components from. For example, the example below loads a dummy pipeline from [`hf-internal-testing/diffusers-dummy-pipeline`](https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py) and the pipeline weights and components from [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32):
+There are two types of community pipelines: those stored on the Hugging Face Hub and those stored in the Diffusers GitHub repository. Hub pipelines are completely customizable (scheduler, models, pipeline code, etc.), while Diffusers GitHub pipelines are limited to custom pipeline code. Refer to this [table](./contribute_pipeline#share-your-pipeline) for a more detailed comparison of Hub vs GitHub community pipelines.
-
+
+> [!WARNING]
+> By loading a community pipeline from the Hugging Face Hub, you are trusting that the code you are loading is safe. Make sure to inspect the code online before loading and running it automatically!
```py
from diffusers import DiffusionPipeline
@@ -36,7 +38,10 @@ pipeline = DiffusionPipeline.from_pretrained(
)
```
-Loading an official community pipeline is similar, but you can mix loading weights from an official repository id and pass pipeline components directly. The example below loads the community [CLIP Guided Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#clip-guided-stable-diffusion) pipeline, and you can pass the CLIP model components directly to it:
+
+
+
+To load a GitHub community pipeline, pass the repository id of the community pipeline to the `custom_pipeline` argument and the model repository where you'd like to load the pipeline weights and components from. You can also load model components directly. The example below loads the community [CLIP Guided Stable Diffusion](https://github.com/huggingface/diffusers/tree/main/examples/community#clip-guided-stable-diffusion) pipeline and the CLIP model components.
```py
from diffusers import DiffusionPipeline
@@ -56,9 +61,12 @@ pipeline = DiffusionPipeline.from_pretrained(
)
```
+
+
+
### Load from a local file
-Community pipelines can also be loaded from a local file if you pass a file path instead. The path to the passed directory must contain a `pipeline.py` file that contains the pipeline class in order to successfully load it.
+Community pipelines can also be loaded from a local file if you pass a file path instead. The path to the passed directory must contain a pipeline.py file that contains the pipeline class.
```py
pipeline = DiffusionPipeline.from_pretrained(
@@ -77,7 +85,7 @@ By default, community pipelines are loaded from the latest stable version of Dif
-For example, to load from the `main` branch:
+For example, to load from the main branch:
```py
pipeline = DiffusionPipeline.from_pretrained(
@@ -93,7 +101,7 @@ pipeline = DiffusionPipeline.from_pretrained(
-For example, to load from a previous version of Diffusers like `v0.25.0`:
+For example, to load from a previous version of Diffusers like v0.25.0:
```py
pipeline = DiffusionPipeline.from_pretrained(
@@ -109,8 +117,49 @@ pipeline = DiffusionPipeline.from_pretrained(
+### Load with from_pipe
-For more information about community pipelines, take a look at the [Community pipelines](custom_pipeline_examples) guide for how to use them and if you're interested in adding a community pipeline check out the [How to contribute a community pipeline](contribute_pipeline) guide!
+Community pipelines can also be loaded with the [`~DiffusionPipeline.from_pipe`] method which allows you to load and reuse multiple pipelines without any additional memory overhead (learn more in the [Reuse a pipeline](./loading#reuse-a-pipeline) guide). The memory requirement is determined by the largest single pipeline loaded.
+
+For example, let's load a community pipeline that supports [long prompts with weighting](https://github.com/huggingface/diffusers/tree/main/examples/community#long-prompt-weighting-stable-diffusion) from a Stable Diffusion pipeline.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipe_sd = DiffusionPipeline.from_pretrained("emilianJR/CyberRealistic_V3", torch_dtype=torch.float16)
+pipe_sd.to("cuda")
+# load long prompt weighting pipeline
+pipe_lpw = DiffusionPipeline.from_pipe(
+ pipe_sd,
+ custom_pipeline="lpw_stable_diffusion",
+).to("cuda")
+
+prompt = "cat, hiding in the leaves, ((rain)), zazie rainyday, beautiful eyes, macro shot, colorful details, natural lighting, amazing composition, subsurface scattering, amazing textures, filmic, soft light, ultra-detailed eyes, intricate details, detailed texture, light source contrast, dramatic shadows, cinematic light, depth of field, film grain, noise, dark background, hyperrealistic dslr film still, dim volumetric cinematic lighting"
+neg_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
+generator = torch.Generator(device="cpu").manual_seed(20)
+out_lpw = pipe_lpw(
+ prompt,
+ negative_prompt=neg_prompt,
+ width=512,
+ height=512,
+ max_embeddings_multiples=3,
+ num_inference_steps=50,
+ generator=generator,
+ ).images[0]
+out_lpw
+```
+
Stable Diffusion with long prompt weighting
+
Stable Diffusion
+
## Community components
@@ -118,7 +167,7 @@ Community components allow users to build pipelines that may have customized com
This section shows how users should use community components to build a community pipeline.
-You'll use the [showlab/show-1-base](https://huggingface.co/showlab/show-1-base) pipeline checkpoint as an example. So, let's start loading the components:
+You'll use the [showlab/show-1-base](https://huggingface.co/showlab/show-1-base) pipeline checkpoint as an example.
1. Import and load the text encoder from Transformers:
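A minimal sketch of this step is shown below; it assumes the checkpoint exposes a T5 text encoder in the usual `tokenizer` and `text_encoder` subfolders, as in the original Show-1 code.

```python
from transformers import T5Tokenizer, T5EncoderModel

pipe_id = "showlab/show-1-base"

# load the tokenizer and text encoder from their subfolders in the checkpoint
tokenizer = T5Tokenizer.from_pretrained(pipe_id, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(pipe_id, subfolder="text_encoder")
```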
@@ -152,17 +201,17 @@ In steps 4 and 5, the custom [UNet](https://github.com/showlab/Show-1/blob/main/
-4. Now you'll load a [custom UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py), which in this example, has already been implemented in the `showone_unet_3d_condition.py` [script](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) for your convenience. You'll notice the `UNet3DConditionModel` class name is changed to `ShowOneUNet3DConditionModel` because [`UNet3DConditionModel`] already exists in Diffusers. Any components needed for the `ShowOneUNet3DConditionModel` class should be placed in the `showone_unet_3d_condition.py` script.
+4. Now you'll load a [custom UNet](https://github.com/showlab/Show-1/blob/main/showone/models/unet_3d_condition.py), which in this example has already been implemented in [showone_unet_3d_condition.py](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) for your convenience. You'll notice the [`UNet3DConditionModel`] class name is changed to `ShowOneUNet3DConditionModel` because [`UNet3DConditionModel`] already exists in Diffusers. Any components needed for the `ShowOneUNet3DConditionModel` class should be placed in showone_unet_3d_condition.py.
-Once this is done, you can initialize the UNet:
+ Once this is done, you can initialize the UNet:
-```python
-from showone_unet_3d_condition import ShowOneUNet3DConditionModel
+ ```python
+ from showone_unet_3d_condition import ShowOneUNet3DConditionModel
-unet = ShowOneUNet3DConditionModel.from_pretrained(pipe_id, subfolder="unet")
-```
+ unet = ShowOneUNet3DConditionModel.from_pretrained(pipe_id, subfolder="unet")
+ ```
-5. Finally, you'll load the custom pipeline code. For this example, it has already been created for you in the `pipeline_t2v_base_pixel.py` [script](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/pipeline_t2v_base_pixel.py). This script contains a custom `TextToVideoIFPipeline` class for generating videos from text. Just like the custom UNet, any code needed for the custom pipeline to work should go in the `pipeline_t2v_base_pixel.py` script.
+5. Finally, you'll load the custom pipeline code. For this example, it has already been created for you in [pipeline_t2v_base_pixel.py](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/pipeline_t2v_base_pixel.py). This script contains a custom `TextToVideoIFPipeline` class for generating videos from text. Just like the custom UNet, any code needed for the custom pipeline to work should go in pipeline_t2v_base_pixel.py.
Once everything is in place, you can initialize the `TextToVideoIFPipeline` with the `ShowOneUNet3DConditionModel`:
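A minimal sketch of that initialization, assuming `unet` comes from step 4 and the tokenizer, text encoder, scheduler, and feature extractor were loaded in the earlier steps:

```python
import torch
from pipeline_t2v_base_pixel import TextToVideoIFPipeline

# assemble the custom pipeline from the components loaded in the previous steps
pipeline = TextToVideoIFPipeline(
    unet=unet,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    scheduler=scheduler,
    feature_extractor=feature_extractor,
)
pipeline = pipeline.to(device="cuda")
pipeline.torch_dtype = torch.float16
```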
@@ -187,13 +236,16 @@ Push the pipeline to the Hub to share with the community!
pipeline.push_to_hub("custom-t2v-pipeline")
```
-After the pipeline is successfully pushed, you need a couple of changes:
+After the pipeline is successfully pushed, you need to make a few changes:
-1. Change the `_class_name` attribute in [`model_index.json`](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/model_index.json#L2) to `"pipeline_t2v_base_pixel"` and `"TextToVideoIFPipeline"`.
-2. Upload `showone_unet_3d_condition.py` to the `unet` [directory](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py).
-3. Upload `pipeline_t2v_base_pixel.py` to the pipeline base [directory](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py).
+1. Change the `_class_name` attribute in [model_index.json](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/model_index.json#L2) to `"pipeline_t2v_base_pixel"` and `"TextToVideoIFPipeline"`.
+2. Upload `showone_unet_3d_condition.py` to the [unet](https://huggingface.co/sayakpaul/show-1-base-with-code/blob/main/unet/showone_unet_3d_condition.py) subfolder.
+3. Upload `pipeline_t2v_base_pixel.py` to the pipeline [repository](https://huggingface.co/sayakpaul/show-1-base-with-code/tree/main).
-To run inference, simply add the `trust_remote_code` argument while initializing the pipeline to handle all the "magic" behind the scenes.
+To run inference, add the `trust_remote_code` argument while initializing the pipeline to handle all the "magic" behind the scenes.
+
+> [!WARNING]
+> As an additional precaution with `trust_remote_code=True`, we strongly encourage you to pass a commit hash to the `revision` parameter in [`~DiffusionPipeline.from_pretrained`] to make sure the code hasn't been updated with some malicious new lines of code (unless you fully trust the model owners).
```python
from diffusers import DiffusionPipeline
@@ -221,10 +273,9 @@ video_frames = pipeline(
).frames
```
-As an additional reference example, you can refer to the repository structure of [stabilityai/japanese-stable-diffusion-xl](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl/), that makes use of the `trust_remote_code` feature:
+As an additional reference, take a look at the repository structure of [stabilityai/japanese-stable-diffusion-xl](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl/), which also uses the `trust_remote_code` feature.
```python
-
from diffusers import DiffusionPipeline
import torch
@@ -232,14 +283,4 @@ pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/japanese-stable-diffusion-xl", trust_remote_code=True
)
pipeline.to("cuda")
-
-# if using torch < 2.0
-# pipeline.enable_xformers_memory_efficient_attention()
-
-prompt = "柴犬、カラフルアート"
-
-image = pipeline(prompt=prompt).images[0]
```
-
-> [!TIP]
-> When using `trust_remote_code=True`, it is also strongly encouraged to pass a commit hash as a `revision` to make sure the author of the models did not update the code with some malicious new lines (unless you fully trust the authors of the models).
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/ip_adapter.md b/docs/source/en/using-diffusers/ip_adapter.md
index 4ae403538d2b..dc64b2548529 100644
--- a/docs/source/en/using-diffusers/ip_adapter.md
+++ b/docs/source/en/using-diffusers/ip_adapter.md
@@ -362,14 +362,12 @@ IP-Adapter's image prompting and compatibility with other adapters and models ma
### Face model
-Generating accurate faces is challenging because they are complex and nuanced. Diffusers supports two IP-Adapter checkpoints specifically trained to generate faces:
+Generating accurate faces is challenging because they are complex and nuanced. Diffusers supports two IP-Adapter checkpoints specifically trained to generate faces from the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repository:
* [ip-adapter-full-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-full-face_sd15.safetensors) is conditioned with images of cropped faces and removed backgrounds
* [ip-adapter-plus-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-plus-face_sd15.safetensors) uses patch embeddings and is conditioned with images of cropped faces
-> [!TIP]
->
-> [IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) is a face-specific IP-Adapter trained with face ID embeddings instead of CLIP image embeddings, allowing you to generate more consistent faces in different contexts and styles. Try out this popular [community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#ip-adapter-face-id) and see how it compares to the other face IP-Adapters.
+Additionally, Diffusers supports all IP-Adapter checkpoints trained with face embeddings extracted by `insightface` face models. Supported models are from the [h94/IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) repository.
For face models, use the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) checkpoint. It is also recommended to use [`DDIMScheduler`] or [`EulerDiscreteScheduler`] for face models.
@@ -411,6 +409,56 @@ image
+To use IP-Adapter FaceID models, first extract face embeddings with `insightface`. Then pass the list of tensors to the pipeline as `ip_adapter_image_embeds`.
+
+```py
+import cv2
+import numpy as np
+import torch
+from diffusers import StableDiffusionPipeline, DDIMScheduler
+from diffusers.utils import load_image
+from insightface.app import FaceAnalysis
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ torch_dtype=torch.float16,
+).to("cuda")
+pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sd15.bin", image_encoder_folder=None)
+pipeline.set_ip_adapter_scale(0.6)
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png")
+
+ref_images_embeds = []
+app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+app.prepare(ctx_id=0, det_size=(640, 640))
+image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB)
+faces = app.get(image)
+image = torch.from_numpy(faces[0].normed_embedding)
+ref_images_embeds.append(image.unsqueeze(0))
+ref_images_embeds = torch.stack(ref_images_embeds, dim=0).unsqueeze(0)
+neg_ref_images_embeds = torch.zeros_like(ref_images_embeds)
+id_embeds = torch.cat([neg_ref_images_embeds, ref_images_embeds]).to(dtype=torch.float16, device="cuda")
+
+generator = torch.Generator(device="cpu").manual_seed(42)
+
+images = pipeline(
+ prompt="A photo of a girl",
+ ip_adapter_image_embeds=[id_embeds],
+ negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+ num_inference_steps=20, num_images_per_prompt=1,
+ generator=generator
+).images
+```
+
+Both IP-Adapter FaceID Plus and Plus v2 models require CLIP image embeddings. Prepare the face embeddings as shown previously, then extract and pass the CLIP embeddings to the hidden image projection layers.
+
+```py
+# `ip_adapter_images` (the list of face images) and `num_images` are assumed to be defined as in the previous example
+clip_embeds = pipeline.prepare_ip_adapter_image_embeds([ip_adapter_images], None, torch.device("cuda"), num_images, True)[0]
+
+pipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = clip_embeds.to(dtype=torch.float16)
+pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut = False # True if Plus v2
+```
+
+
### Multi IP-Adapter
More than one IP-Adapter can be used at the same time to generate specific images in more diverse styles. For example, you can use IP-Adapter-Face to generate consistent faces and characters, and IP-Adapter Plus to generate those faces in a specific style.
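A rough sketch of how this could look, assuming an SDXL pipeline is already loaded and that the weight names below exist in the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repository:

```py
# load a face IP-Adapter and a style (plus) IP-Adapter at the same time
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder=["sdxl_models", "sdxl_models"],
    weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors", "ip-adapter-plus_sdxl_vit-h.safetensors"],
)
# weight the face adapter more heavily than the style adapter
pipeline.set_ip_adapter_scale([0.7, 0.3])
```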
diff --git a/docs/source/en/using-diffusers/loading.md b/docs/source/en/using-diffusers/loading.md
index 9d5534154fc8..e7b2c4b7acb3 100644
--- a/docs/source/en/using-diffusers/loading.md
+++ b/docs/source/en/using-diffusers/loading.md
@@ -10,57 +10,75 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Load pipelines, models, and schedulers
+# Load pipelines
[[open-in-colab]]
-Having an easy way to use a diffusion system for inference is essential to 🧨 Diffusers. Diffusion systems often consist of multiple components like parameterized models, tokenizers, and schedulers that interact in complex ways. That is why we designed the [`DiffusionPipeline`] to wrap the complexity of the entire diffusion system into an easy-to-use API, while remaining flexible enough to be adapted for other use cases, such as loading each component individually as building blocks to assemble your own diffusion system.
-
-Everything you need for inference or training is accessible with the `from_pretrained()` method.
+Diffusion systems consist of multiple components like parameterized models and schedulers that interact in complex ways. That is why we designed the [`DiffusionPipeline`] to wrap the complexity of the entire diffusion system into an easy-to-use API. At the same time, the [`DiffusionPipeline`] is entirely customizable so you can modify each component to build a diffusion system for your use case.
This guide will show you how to load:
- pipelines from the Hub and locally
- different components into a pipeline
+- multiple pipelines without increasing memory usage
- checkpoint variants such as different floating point types or non-exponential mean averaged (EMA) weights
-- models and schedulers
-## Diffusion Pipeline
+## Load a pipeline
+
+> [!TIP]
+> Skip to the [DiffusionPipeline explained](#diffusionpipeline-explained) section if you're interested in an explanation about how the [`DiffusionPipeline`] class works.
-
+There are two ways to load a pipeline for a task:
-💡 Skip to the [DiffusionPipeline explained](#diffusionpipeline-explained) section if you are interested in learning in more detail about how the [`DiffusionPipeline`] class works.
+1. Load the generic [`DiffusionPipeline`] class and allow it to automatically detect the correct pipeline class from the checkpoint.
+2. Load a specific pipeline class for a specific task.
-
+
+
-The [`DiffusionPipeline`] class is the simplest and most generic way to load the latest trending diffusion model from the [Hub](https://huggingface.co/models?library=diffusers&sort=trending). The [`DiffusionPipeline.from_pretrained`] method automatically detects the correct pipeline class from the checkpoint, downloads, and caches all the required configuration and weight files, and returns a pipeline instance ready for inference.
+The [`DiffusionPipeline`] class is a simple and generic way to load the latest trending diffusion model from the [Hub](https://huggingface.co/models?library=diffusers&sort=trending). It uses the [`~DiffusionPipeline.from_pretrained`] method to automatically detect the correct pipeline class for a task from the checkpoint, download and cache all the required configuration and weight files, and return a pipeline ready for inference.
```python
from diffusers import DiffusionPipeline
-repo_id = "runwayml/stable-diffusion-v1-5"
-pipe = DiffusionPipeline.from_pretrained(repo_id, use_safetensors=True)
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
```
-You can also load a checkpoint with its specific pipeline class. The example above loaded a Stable Diffusion model; to get the same result, use the [`StableDiffusionPipeline`] class:
+This same checkpoint can also be used for an image-to-image task. The [`DiffusionPipeline`] class can handle any task as long as you provide the appropriate inputs. For example, for an image-to-image task, you need to pass an initial image to the pipeline.
+
+```py
+from diffusers import DiffusionPipeline
+from diffusers.utils import load_image
+
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
+
+init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png")
+prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+image = pipeline(prompt, image=init_image).images[0]
+```
+
+
+
+
+Checkpoints can be loaded by their specific pipeline class if you already know it. For example, to load a Stable Diffusion model, use the [`StableDiffusionPipeline`] class.
```python
from diffusers import StableDiffusionPipeline
-repo_id = "runwayml/stable-diffusion-v1-5"
-pipe = StableDiffusionPipeline.from_pretrained(repo_id, use_safetensors=True)
+pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
```
-A checkpoint (such as [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) or [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5)) may also be used for more than one task, like text-to-image or image-to-image. To differentiate what task you want to use the checkpoint for, you have to load it directly with its corresponding task-specific pipeline class:
+This same checkpoint may also be used for another task like image-to-image. To differentiate what task you want to use the checkpoint for, you have to use the corresponding task-specific pipeline class. For example, to use the same checkpoint for image-to-image, use the [`StableDiffusionImg2ImgPipeline`] class.
-```python
+```py
from diffusers import StableDiffusionImg2ImgPipeline
-repo_id = "runwayml/stable-diffusion-v1-5"
-pipe = StableDiffusionImg2ImgPipeline.from_pretrained(repo_id)
+pipeline = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
```
-You can use the Space below to gauge the memory requirements of a pipeline you want to load beforehand without downloading the pipeline checkpoints:
+
+
+
+Use the Space below to gauge a pipeline's memory requirements before you download and load it, so you can check whether it will run on your hardware.
-## Weighting
+### Weighting
You'll notice there is no "ball" in the image! Let's use compel to upweight the concept of "ball" in the prompt. Create a [`Compel`](https://github.com/damian0815/compel/blob/main/doc/compel.md#compel-objects) object, and pass it a tokenizer and text encoder:
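A minimal sketch of this step, assuming `pipeline` is an already loaded Stable Diffusion pipeline:

```py
from compel import Compel

compel_proc = Compel(tokenizer=pipeline.tokenizer, text_encoder=pipeline.text_encoder)

# "++" upweights the token it follows, here the concept of "ball"
prompt_embeds = compel_proc("a red cat playing with a ball++")
image = pipeline(prompt_embeds=prompt_embeds, num_inference_steps=20).images[0]
```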
@@ -123,7 +322,7 @@ image
-## Blending
+### Blending
You can also create a weighted *blend* of prompts by adding `.blend()` to a list of prompts and passing it some weights. Your blend may not always produce the result you expect because it breaks some assumptions about how the text encoder functions, so just have fun and experiment with it!
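A short sketch of the `.blend()` syntax, reusing the `compel_proc` object from above (the prompts and weights are illustrative):

```py
# blend two prompts with weights 0.7 and 0.8
prompt_embeds = compel_proc('("a red cat playing with a ball", "jungle").blend(0.7, 0.8)')
image = pipeline(prompt_embeds=prompt_embeds, num_inference_steps=20).images[0]
```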
@@ -139,7 +338,7 @@ image
-## Conjunction
+### Conjunction
A conjunction diffuses each prompt independently and concatenates their results by their weighted sum. Add `.and()` to the end of a list of prompts to create a conjunction:
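A short sketch of the `.and()` syntax, again reusing `compel_proc`:

```py
# diffuse each prompt independently and combine the results by their weighted sum
prompt_embeds = compel_proc('["a red cat", "playing with a ball"].and()')
image = pipeline(prompt_embeds=prompt_embeds, num_inference_steps=20).images[0]
```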
@@ -155,7 +354,7 @@ image
-## Textual inversion
+### Textual inversion
[Textual inversion](../training/text_inversion) is a technique for learning a specific concept from some images which you can use to generate new images conditioned on that concept.
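As a rough sketch, a textual inversion embedding can be loaded into a pipeline and then referenced by its placeholder token in a prompt; the concept repository and token below are only examples:

```py
# load a textual inversion embedding; its placeholder token can then be used in prompts
pipeline.load_textual_inversion("sd-concepts-library/midjourney-style")
image = pipeline("a painting of a castle in the style of <midjourney-style>").images[0]
```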
@@ -195,7 +394,7 @@ image
-## DreamBooth
+### DreamBooth
[DreamBooth](../training/dreambooth) is a technique for generating contextualized images of a subject given just a few images of the subject to train on. It is similar to textual inversion, but DreamBooth trains the full model whereas textual inversion only fine-tunes the text embeddings. This means you should use [`~DiffusionPipeline.from_pretrained`] to load the DreamBooth model (feel free to browse the [Stable Diffusion Dreambooth Concepts Library](https://huggingface.co/sd-dreambooth-library) for 100+ trained models):
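A rough sketch of loading one of those models; the repository id and trigger word below are illustrative:

```py
import torch
from diffusers import DiffusionPipeline

# a DreamBooth model loads like any other full checkpoint
pipeline = DiffusionPipeline.from_pretrained(
    "sd-dreambooth-library/dndcoverart-v1", torch_dtype=torch.float16
).to("cuda")
image = pipeline("dndcoverart of a dragon flying over a castle").images[0]
```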
@@ -221,7 +420,7 @@ image
-## Stable Diffusion XL
+### Stable Diffusion XL
Stable Diffusion XL (SDXL) has two tokenizers and text encoders so its usage is a bit different. To address this, you should pass both tokenizers and encoders to the `Compel` class:
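A sketch of what that looks like, assuming `pipeline` is a loaded SDXL pipeline; the exact `ReturnedEmbeddingsType` value may depend on your compel version:

```py
from compel import Compel, ReturnedEmbeddingsType

compel = Compel(
    tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
    text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)

# with requires_pooled, compel returns both the prompt embeddings and the pooled embeddings
prompt_embeds, pooled_prompt_embeds = compel("a red cat playing with a ball++")
image = pipeline(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds).images[0]
```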
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
index 5ce94680aeb2..6cdf2e7b21ab 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
@@ -23,6 +23,7 @@
import re
import shutil
import warnings
+from contextlib import nullcontext
from pathlib import Path
from typing import List, Optional
@@ -1844,7 +1845,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}
- with torch.cuda.amp.autocast():
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
images = [
pipeline(**pipeline_args, generator=generator).images[0]
for _ in range(args.num_validation_images)
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index ff272e3b902e..21a84b77245a 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
import argparse
-import contextlib
import gc
import hashlib
import itertools
@@ -26,6 +25,7 @@
import re
import shutil
import warnings
+from contextlib import nullcontext
from pathlib import Path
from typing import List, Optional
@@ -2192,13 +2192,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}
- inference_ctx = (
- contextlib.nullcontext()
- if "playground" in args.pretrained_model_name_or_path
- else torch.cuda.amp.autocast()
- )
+ if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
- with inference_ctx:
+ with autocast_ctx:
images = [
pipeline(**pipeline_args, generator=generator).images[0]
for _ in range(args.num_validation_images)
diff --git a/examples/amused/train_amused.py b/examples/amused/train_amused.py
index 33673b3f7eb7..3ec0503dfdfe 100644
--- a/examples/amused/train_amused.py
+++ b/examples/amused/train_amused.py
@@ -430,6 +430,9 @@ def main(args):
log_with=args.report_to,
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
if accelerator.is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
diff --git a/examples/community/README.md b/examples/community/README.md
index cc471874ca02..5cebc4f9f049 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -3819,12 +3819,10 @@ export_to_gif(frames, "animation.gif")
IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
You need to install `insightface` and all its requirements to use this model.
You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
-You have to disable PEFT BACKEND in order to load weights.
You can find more results [here](https://github.com/huggingface/diffusers/pull/6276).
```py
import diffusers
-diffusers.utils.USE_PEFT_BACKEND = False
import torch
from diffusers.utils import load_image
import cv2
diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py
index dbea6836fffc..c2a60bed6426 100644
--- a/examples/community/composable_stable_diffusion.py
+++ b/examples/community/composable_stable_diffusion.py
@@ -321,7 +321,12 @@ def check_inputs(self, prompt, height, width, callback_steps):
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
diff --git a/examples/community/gluegen.py b/examples/community/gluegen.py
index b8f147000229..c656dce55a0d 100644
--- a/examples/community/gluegen.py
+++ b/examples/community/gluegen.py
@@ -500,7 +500,12 @@ def check_inputs(
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/community/instaflow_one_step.py b/examples/community/instaflow_one_step.py
index b07d85f8fcdf..b0476d3afe38 100644
--- a/examples/community/instaflow_one_step.py
+++ b/examples/community/instaflow_one_step.py
@@ -468,7 +468,12 @@ def check_inputs(
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/community/ip_adapter_face_id.py b/examples/community/ip_adapter_face_id.py
index b4d2446b5ce9..bb5a2a4fe5a9 100644
--- a/examples/community/ip_adapter_face_id.py
+++ b/examples/community/ip_adapter_face_id.py
@@ -26,7 +26,14 @@
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.models.lora import LoRALinearLayer, adjust_lora_scale_text_encoder
+from diffusers.models.attention_processor import (
+ AttnProcessor,
+ AttnProcessor2_0,
+ IPAdapterAttnProcessor,
+ IPAdapterAttnProcessor2_0,
+)
+from diffusers.models.embeddings import MultiIPAdapterImageProjection
+from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -45,300 +52,6 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class LoRAIPAdapterAttnProcessor(nn.Module):
- r"""
- Attention processor for IP-Adapater.
- Args:
- hidden_size (`int`):
- The hidden size of the attention layer.
- cross_attention_dim (`int`):
- The number of channels in the `encoder_hidden_states`.
- rank (`int`, defaults to 4):
- The dimension of the LoRA update matrices.
- network_alpha (`int`, *optional*):
- Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
- lora_scale (`float`, defaults to 1.0):
- the weight scale of LoRA.
- scale (`float`, defaults to 1.0):
- the weight scale of image prompt.
- num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
- The context length of the image features.
- """
-
- def __init__(
- self,
- hidden_size,
- cross_attention_dim=None,
- rank=4,
- network_alpha=None,
- lora_scale=1.0,
- scale=1.0,
- num_tokens=4,
- ):
- super().__init__()
-
- self.rank = rank
- self.lora_scale = lora_scale
-
- self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
- self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
- self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
- self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
- self.scale = scale
- self.num_tokens = num_tokens
-
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
- def __call__(
- self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- ):
- residual = hidden_states
-
- # separate ip_hidden_states from encoder_hidden_states
- if encoder_hidden_states is not None:
- if isinstance(encoder_hidden_states, tuple):
- encoder_hidden_states, ip_hidden_states = encoder_hidden_states
- else:
- deprecation_message = (
- "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
- " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
- )
- deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
- encoder_hidden_states, ip_hidden_states = (
- encoder_hidden_states[:, :end_pos, :],
- [encoder_hidden_states[:, end_pos:, :]],
- )
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
-
- query = attn.head_to_batch_dim(query)
- key = attn.head_to_batch_dim(key)
- value = attn.head_to_batch_dim(value)
-
- attention_probs = attn.get_attention_scores(query, key, attention_mask)
- hidden_states = torch.bmm(attention_probs, value)
- hidden_states = attn.batch_to_head_dim(hidden_states)
-
- # for ip-adapter
- ip_key = self.to_k_ip(ip_hidden_states)
- ip_value = self.to_v_ip(ip_hidden_states)
-
- ip_key = attn.head_to_batch_dim(ip_key)
- ip_value = attn.head_to_batch_dim(ip_value)
-
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
- ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
- ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
-
- hidden_states = hidden_states + self.scale * ip_hidden_states
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
-class LoRAIPAdapterAttnProcessor2_0(nn.Module):
- r"""
- Attention processor for IP-Adapater for PyTorch 2.0.
- Args:
- hidden_size (`int`):
- The hidden size of the attention layer.
- cross_attention_dim (`int`):
- The number of channels in the `encoder_hidden_states`.
- rank (`int`, defaults to 4):
- The dimension of the LoRA update matrices.
- network_alpha (`int`, *optional*):
- Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
- lora_scale (`float`, defaults to 1.0):
- the weight scale of LoRA.
- scale (`float`, defaults to 1.0):
- the weight scale of image prompt.
- num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
- The context length of the image features.
- """
-
- def __init__(
- self,
- hidden_size,
- cross_attention_dim=None,
- rank=4,
- network_alpha=None,
- lora_scale=1.0,
- scale=1.0,
- num_tokens=4,
- ):
- super().__init__()
-
- self.rank = rank
- self.lora_scale = lora_scale
-
- self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
- self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
- self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
- self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
- self.scale = scale
- self.num_tokens = num_tokens
-
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
-
- def __call__(
- self,
- attn,
- hidden_states,
- encoder_hidden_states=None,
- attention_mask=None,
- temb=None,
- ):
- residual = hidden_states
-
- # separate ip_hidden_states from encoder_hidden_states
- if encoder_hidden_states is not None:
- if isinstance(encoder_hidden_states, tuple):
- encoder_hidden_states, ip_hidden_states = encoder_hidden_states
- else:
- deprecation_message = (
- "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
- " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
- )
- deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
- end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
- encoder_hidden_states, ip_hidden_states = (
- encoder_hidden_states[:, :end_pos, :],
- [encoder_hidden_states[:, end_pos:, :]],
- )
-
- if attn.spatial_norm is not None:
- hidden_states = attn.spatial_norm(hidden_states, temb)
-
- input_ndim = hidden_states.ndim
-
- if input_ndim == 4:
- batch_size, channel, height, width = hidden_states.shape
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
- batch_size, sequence_length, _ = (
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
- )
-
- if attention_mask is not None:
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
- # scaled_dot_product_attention expects attention_mask shape to be
- # (batch, heads, source_length, target_length)
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
- if attn.group_norm is not None:
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
- query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
-
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- elif attn.norm_cross:
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
- key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- # for ip-adapter
- ip_key = self.to_k_ip(ip_hidden_states)
- ip_value = self.to_v_ip(ip_hidden_states)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
-
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- ip_hidden_states = ip_hidden_states.to(query.dtype)
-
- hidden_states = hidden_states + self.scale * ip_hidden_states
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- if input_ndim == 4:
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
- if attn.residual_connection:
- hidden_states = hidden_states + residual
-
- hidden_states = hidden_states / attn.rescale_output_factor
-
- return hidden_states
-
-
class IPAdapterFullImageProjection(nn.Module):
def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
super().__init__()
@@ -615,17 +328,13 @@ def convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
return image_projection
def _load_ip_adapter_weights(self, state_dict):
- from diffusers.models.attention_processor import (
- AttnProcessor,
- AttnProcessor2_0,
- )
-
num_image_text_embeds = 4
self.unet.encoder_hid_proj = None
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
+ lora_dict = {}
key_id = 0
for name in self.unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
@@ -642,94 +351,99 @@ def _load_ip_adapter_weights(self, state_dict):
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
)
attn_procs[name] = attn_processor_class()
- rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0]
- attn_module = self.unet
- for n in name.split(".")[:-1]:
- attn_module = getattr(attn_module, n)
- # Set the `lora_layer` attribute of the attention-related matrices.
- attn_module.to_q.set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_q.in_features,
- out_features=attn_module.to_q.out_features,
- rank=rank,
- )
+
+ lora_dict.update(
+ {f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}
)
- attn_module.to_k.set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_k.in_features,
- out_features=attn_module.to_k.out_features,
- rank=rank,
- )
+ lora_dict.update(
+ {f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}
)
- attn_module.to_v.set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_v.in_features,
- out_features=attn_module.to_v.out_features,
- rank=rank,
- )
+ lora_dict.update(
+ {f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}
)
- attn_module.to_out[0].set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_out[0].in_features,
- out_features=attn_module.to_out[0].out_features,
- rank=rank,
- )
+ lora_dict.update(
+ {
+ f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
+ f"{key_id}.to_out_lora.down.weight"
+ ]
+ }
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}
)
-
- value_dict = {}
- for k, module in attn_module.named_children():
- index = "."
- if not hasattr(module, "set_lora_layer"):
- index = ".0."
- module = module[0]
- lora_layer = getattr(module, "lora_layer")
- for lora_name, w in lora_layer.state_dict().items():
- value_dict.update(
- {
- f"{k}{index}lora_layer.{lora_name}": state_dict["ip_adapter"][
- f"{key_id}.{k}_lora.{lora_name}"
- ]
- }
- )
-
- attn_module.load_state_dict(value_dict, strict=False)
- attn_module.to(dtype=self.dtype, device=self.device)
key_id += 1
else:
- rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0]
attn_processor_class = (
- LoRAIPAdapterAttnProcessor2_0
- if hasattr(F, "scaled_dot_product_attention")
- else LoRAIPAdapterAttnProcessor
+ IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
- rank=rank,
num_tokens=num_image_text_embeds,
).to(dtype=self.dtype, device=self.device)
- value_dict = {}
- for k, w in attn_procs[name].state_dict().items():
- value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})
+ lora_dict.update(
+ {f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}
+ )
+ lora_dict.update(
+ {
+ f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
+ f"{key_id}.to_out_lora.down.weight"
+ ]
+ }
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
+ )
+ lora_dict.update(
+ {f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}
+ )
+ value_dict = {}
+ value_dict.update({"to_k_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
+ value_dict.update({"to_v_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
attn_procs[name].load_state_dict(value_dict)
key_id += 1
self.unet.set_attn_processor(attn_procs)
+ self.load_lora_weights(lora_dict, adapter_name="faceid")
+ self.set_adapters(["faceid"], adapter_weights=[1.0])
+
# convert IP-Adapter Image Projection layers to diffusers
image_projection = self.convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
+ image_projection_layers = [image_projection.to(device=self.device, dtype=self.dtype)]
- self.unet.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
+ self.unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.unet.config.encoder_hid_dim_type = "ip_image_proj"
def set_ip_adapter_scale(self, scale):
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
for attn_processor in unet.attn_processors.values():
- if isinstance(attn_processor, (LoRAIPAdapterAttnProcessor, LoRAIPAdapterAttnProcessor2_0)):
- attn_processor.scale = scale
+ if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+ attn_processor.scale = [scale]
def _encode_prompt(
self,
@@ -1039,7 +753,12 @@ def check_inputs(
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -1298,7 +1017,7 @@ def __call__(
negative_image_embeds = torch.zeros_like(image_embeds)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
-
+ image_embeds = [image_embeds]
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -1319,7 +1038,7 @@ def __call__(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Add image embeds for IP-Adapter
- added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else None
+ added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else {}
# 6.2 Optionally get Guidance Scale Embedding
timestep_cond = None
diff --git a/examples/community/latent_consistency_img2img.py b/examples/community/latent_consistency_img2img.py
index 125cea8bde88..35cd74166c68 100644
--- a/examples/community/latent_consistency_img2img.py
+++ b/examples/community/latent_consistency_img2img.py
@@ -177,7 +177,12 @@ def prepare_latents(
latents=None,
generator=None,
):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
diff --git a/examples/community/latent_consistency_interpolate.py b/examples/community/latent_consistency_interpolate.py
index a75e80a678ca..3d2413c99189 100644
--- a/examples/community/latent_consistency_interpolate.py
+++ b/examples/community/latent_consistency_interpolate.py
@@ -472,7 +472,12 @@ def run_safety_checker(self, image, device, dtype):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/community/latent_consistency_txt2img.py b/examples/community/latent_consistency_txt2img.py
index 0f2acbf79637..c31d6abae368 100755
--- a/examples/community/latent_consistency_txt2img.py
+++ b/examples/community/latent_consistency_txt2img.py
@@ -163,7 +163,12 @@ def run_safety_checker(self, image, device, dtype):
return image, has_nsfw_concept
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if latents is None:
latents = torch.randn(shape, dtype=dtype).to(device)
else:
diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index 1128e365bba1..d1c8b357e16f 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -729,7 +729,12 @@ def prepare_latents(
):
if image is None:
batch_size = batch_size * num_images_per_prompt
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py
index af25538cf1cb..64b7973e894b 100644
--- a/examples/community/lpw_stable_diffusion_xl.py
+++ b/examples/community/lpw_stable_diffusion_xl.py
@@ -1060,7 +1060,12 @@ def prepare_latents(
batch_size *= num_images_per_prompt
if image is None:
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -1140,7 +1145,12 @@ def prepare_latents(
return latents
else:
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/community/pipeline_animatediff_controlnet.py b/examples/community/pipeline_animatediff_controlnet.py
index eeb8b1547533..4b1a64142704 100644
--- a/examples/community/pipeline_animatediff_controlnet.py
+++ b/examples/community/pipeline_animatediff_controlnet.py
@@ -373,18 +373,29 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
- def encode_image(self, image, device, num_images_per_prompt):
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
- image_embeds = self.image_encoder(image).image_embeds
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
- uncond_image_embeds = torch.zeros_like(image_embeds)
- return image_embeds, uncond_image_embeds
+ return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
diff --git a/examples/community/pipeline_demofusion_sdxl.py b/examples/community/pipeline_demofusion_sdxl.py
index 6818364b5cf0..93e1463638f0 100644
--- a/examples/community/pipeline_demofusion_sdxl.py
+++ b/examples/community/pipeline_demofusion_sdxl.py
@@ -477,7 +477,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/community/pipeline_sdxl_style_aligned.py b/examples/community/pipeline_sdxl_style_aligned.py
index d8ad0dc906eb..88edeeb7ee4c 100644
--- a/examples/community/pipeline_sdxl_style_aligned.py
+++ b/examples/community/pipeline_sdxl_style_aligned.py
@@ -151,7 +151,7 @@ def concat_first(feat: torch.Tensor, dim: int = 2, scale: float = 1.0) -> torch.
return torch.cat((feat, feat_style), dim=dim)
-def calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor]:
+def calc_mean_std(feat: torch.Tensor, eps: float = 1e-5) -> Tuple[torch.Tensor, torch.Tensor]:
feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
feat_mean = feat.mean(dim=-2, keepdims=True)
return feat_mean, feat_std
@@ -919,7 +919,12 @@ def prepare_latents(
batch_size *= num_images_per_prompt
if image is None:
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -999,7 +1004,12 @@ def prepare_latents(
return latents
else:
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/community/pipeline_stable_diffusion_pag.py b/examples/community/pipeline_stable_diffusion_pag.py
index 02dd9a69f473..04f38a888460 100644
--- a/examples/community/pipeline_stable_diffusion_pag.py
+++ b/examples/community/pipeline_stable_diffusion_pag.py
@@ -857,7 +857,12 @@ def check_inputs(
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
index fe94646a4436..82c522b4489a 100644
--- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
+++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
@@ -751,7 +751,12 @@ def check_conditions(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
index de7865d654b0..a85f1c3da6fb 100644
--- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
+++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
@@ -17,7 +17,7 @@
import inspect
from collections.abc import Callable
-from typing import Any, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL
@@ -1211,8 +1211,8 @@ def prepare_control_image(
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
- prompt: Optional[Union[str, list[str]]] = None,
- prompt_2: Optional[Union[str, list[str]]] = None,
+ prompt: Optional[Union[str, List[str]]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None,
mask_image: Optional[Union[torch.Tensor, PIL.Image.Image]] = None,
adapter_image: PipelineImageInput = None,
@@ -1224,11 +1224,11 @@ def __call__(
denoising_start: Optional[float] = None,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
- negative_prompt: Optional[Union[str, list[str]]] = None,
- negative_prompt_2: Optional[Union[str, list[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
- generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[Union[torch.FloatTensor]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -1238,12 +1238,12 @@ def __call__(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
- cross_attention_kwargs: Optional[dict[str, Any]] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
- original_size: Optional[tuple[int, int]] = None,
- crops_coords_top_left: Optional[tuple[int, int]] = (0, 0),
- target_size: Optional[tuple[int, int]] = None,
- adapter_conditioning_scale: Optional[Union[float, list[float]]] = 1.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Optional[Tuple[int, int]] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ adapter_conditioning_scale: Optional[Union[float, List[float]]] = 1.0,
cond_tau: float = 1.0,
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
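The annotation changes in this file (and the Tuple fix in pipeline_sdxl_style_aligned.py above) swap built-in generics such as tuple[int, int] and list[str] for their typing equivalents. Subscripting the built-ins at runtime requires Python 3.9+, so the typing forms presumably keep these community pipelines importable on 3.8. A small self-contained sketch of the distinction:

from typing import Tuple

# Evaluated at definition time, this works on Python 3.8 and newer.
def scale_size(scale: float, base: Tuple[int, int] = (512, 512)) -> Tuple[int, int]:
    return int(base[0] * scale), int(base[1] * scale)

# The built-in form, e.g. `base: tuple[int, int]`, raises
# "TypeError: 'type' object is not subscriptable" on Python 3.8
# unless `from __future__ import annotations` defers evaluation.
print(scale_size(2.0))  # (1024, 1024)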
diff --git a/examples/community/pipeline_stable_diffusion_xl_ipex.py b/examples/community/pipeline_stable_diffusion_xl_ipex.py
index 68ad5dbec77d..a44ccf89eadd 100644
--- a/examples/community/pipeline_stable_diffusion_xl_ipex.py
+++ b/examples/community/pipeline_stable_diffusion_xl_ipex.py
@@ -614,7 +614,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/community/pipeline_zero1to3.py b/examples/community/pipeline_zero1to3.py
index 133aa694c18c..5e02ba286679 100644
--- a/examples/community/pipeline_zero1to3.py
+++ b/examples/community/pipeline_zero1to3.py
@@ -497,7 +497,12 @@ def check_inputs(self, image, height, width, callback_steps):
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py
index 0173ed41bee6..a13497dddcaf 100644
--- a/examples/community/stable_diffusion_controlnet_inpaint.py
+++ b/examples/community/stable_diffusion_controlnet_inpaint.py
@@ -635,7 +635,12 @@ def check_inputs(
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py
index 3b5ed09aa168..dd648fd8c708 100644
--- a/examples/community/stable_diffusion_ipex.py
+++ b/examples/community/stable_diffusion_ipex.py
@@ -533,7 +533,12 @@ def check_inputs(
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py
index 7af404c25b41..c2dd184c2fa4 100644
--- a/examples/community/stable_diffusion_reference.py
+++ b/examples/community/stable_diffusion_reference.py
@@ -609,7 +609,12 @@ def prepare_latents(
Returns:
torch.Tensor: The prepared latent vectors.
"""
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
index 46470be865cd..1e88cb67ee71 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -23,6 +23,7 @@
import os
import random
import shutil
+from contextlib import nullcontext
from pathlib import Path
from typing import List, Union
@@ -238,6 +239,10 @@ def train_dataloader(self):
def log_validation(vae, unet, args, accelerator, weight_dtype, step):
logger.info("Running validation... ")
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
unet = accelerator.unwrap_model(unet)
pipeline = StableDiffusionPipeline.from_pretrained(
@@ -274,7 +279,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
for _, prompt in enumerate(validation_prompts):
images = []
- with torch.autocast("cuda", dtype=weight_dtype):
+ with autocast_ctx:
images = pipeline(
prompt=prompt,
num_inference_steps=4,
@@ -1172,6 +1177,11 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
).input_ids.to(accelerator.device)
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
# 16. Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -1300,7 +1310,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
with torch.no_grad():
- with torch.autocast("cuda"):
+ with autocast_ctx:
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
@@ -1359,7 +1369,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# Note that we do not use a separate target network for LCM-LoRA distillation.
with torch.no_grad():
- with torch.autocast("cuda", dtype=weight_dtype):
+ with autocast_ctx:
target_noise_pred = unet(
x_prev.float(),
timesteps,
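The recurring pattern in this and the following distillation scripts is to pick the autocast context once instead of hard-coding torch.autocast("cuda"): a no-op nullcontext() on Apple's MPS backend, where autocast has historically been unsupported, and a device-aware torch.autocast everywhere else. A condensed sketch of the selection, assuming an Accelerate Accelerator; the with-block body stands in for the teacher and target UNet calls:

from contextlib import nullcontext

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Decide once; the real scripts also pass dtype=weight_dtype where the old
# torch.autocast("cuda", dtype=weight_dtype) call did.
if torch.backends.mps.is_available():
    autocast_ctx = nullcontext()
else:
    autocast_ctx = torch.autocast(accelerator.device.type)

with autocast_ctx:
    pass  # e.g. teacher_unet(...) / target_unet(...) forward passes go here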
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
index a4052324c128..9405c238f937 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
@@ -22,6 +22,7 @@
import os
import random
import shutil
+from contextlib import nullcontext
from pathlib import Path
import accelerate
@@ -146,7 +147,12 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
for _, prompt in enumerate(validation_prompts):
images = []
- with torch.autocast("cuda", dtype=weight_dtype):
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+ with autocast_ctx:
images = pipeline(
prompt=prompt,
num_inference_steps=4,
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
index fc4da48fbc4c..08d6b23d6deb 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -24,6 +24,7 @@
import os
import random
import shutil
+from contextlib import nullcontext
from pathlib import Path
from typing import List, Union
@@ -256,6 +257,10 @@ def train_dataloader(self):
def log_validation(vae, unet, args, accelerator, weight_dtype, step):
logger.info("Running validation... ")
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
unet = accelerator.unwrap_model(unet)
pipeline = StableDiffusionXLPipeline.from_pretrained(
@@ -291,7 +296,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
for _, prompt in enumerate(validation_prompts):
images = []
- with torch.autocast("cuda", dtype=weight_dtype):
+ with autocast_ctx:
images = pipeline(
prompt=prompt,
num_inference_steps=4,
@@ -1353,7 +1358,12 @@ def compute_embeddings(
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
with torch.no_grad():
- with torch.autocast("cuda"):
+ if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
@@ -1416,7 +1426,12 @@ def compute_embeddings(
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# Note that we do not use a separate target network for LCM-LoRA distillation.
with torch.no_grad():
- with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+ with autocast_ctx:
target_noise_pred = unet(
x_prev.float(),
timesteps,
diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
index 8908593b16d3..5dcad9f6cc39 100644
--- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -23,6 +23,7 @@
import os
import random
import shutil
+from contextlib import nullcontext
from pathlib import Path
from typing import List, Union
@@ -252,7 +253,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
for _, prompt in enumerate(validation_prompts):
images = []
- with torch.autocast("cuda"):
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
images = pipeline(
prompt=prompt,
num_inference_steps=4,
@@ -939,7 +945,7 @@ def main(args):
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
# Initialize from (online) unet
- target_unet = UNet2DConditionModel(**teacher_unet.config)
+ target_unet = UNet2DConditionModel.from_config(unet.config)
target_unet.load_state_dict(unet.state_dict())
target_unet.train()
target_unet.requires_grad_(False)
@@ -1257,7 +1263,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
with torch.no_grad():
- with torch.autocast("cuda"):
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
@@ -1315,7 +1326,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
with torch.no_grad():
- with torch.autocast("cuda", dtype=weight_dtype):
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+ with autocast_ctx:
target_noise_pred = target_unet(
x_prev.float(),
timesteps,
diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
index 74d1c007f7f3..a7deca72a86f 100644
--- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -24,6 +24,7 @@
import os
import random
import shutil
+from contextlib import nullcontext
from pathlib import Path
from typing import List, Union
@@ -270,7 +271,12 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
for _, prompt in enumerate(validation_prompts):
images = []
- with torch.autocast("cuda"):
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
images = pipeline(
prompt=prompt,
num_inference_steps=4,
@@ -998,7 +1004,7 @@ def main(args):
# 8. Create target student U-Net. This will be updated via EMA updates (polyak averaging).
# Initialize from (online) unet
- target_unet = UNet2DConditionModel(**teacher_unet.config)
+ target_unet = UNet2DConditionModel.from_config(unet.config)
target_unet.load_state_dict(unet.state_dict())
target_unet.train()
target_unet.requires_grad_(False)
@@ -1355,7 +1361,12 @@ def compute_embeddings(
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
# solver timestep.
with torch.no_grad():
- with torch.autocast("cuda"):
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = teacher_unet(
noisy_model_input.to(weight_dtype),
@@ -1417,7 +1428,12 @@ def compute_embeddings(
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
with torch.no_grad():
- with torch.autocast("cuda", dtype=weight_dtype):
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
+
+ with autocast_ctx:
target_noise_pred = target_unet(
x_prev.float(),
timesteps,
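Both WDS distillation scripts above also change how the EMA target network is built: instead of unpacking the teacher's config into the constructor, they call UNet2DConditionModel.from_config on the online student's config, so the target matches the student's architecture (the student may carry extras such as time_cond_proj_dim) and the subsequent load_state_dict lines up. A hedged sketch of that construction as a helper, with `unet` standing for the student created earlier in the script:

from diffusers import UNet2DConditionModel

def build_ema_target(unet: UNet2DConditionModel) -> UNet2DConditionModel:
    # from_config re-instantiates the architecture from the registered config and
    # ignores bookkeeping keys (e.g. _class_name, _diffusers_version) that a bare
    # UNet2DConditionModel(**unet.config) call may not accept.
    target_unet = UNet2DConditionModel.from_config(unet.config)
    target_unet.load_state_dict(unet.state_dict())
    target_unet.train()
    target_unet.requires_grad_(False)
    return target_unet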
diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py
index a56e92de2661..3daca0e3f56b 100644
--- a/examples/controlnet/train_controlnet.py
+++ b/examples/controlnet/train_controlnet.py
@@ -752,6 +752,10 @@ def main(args):
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py
index b60280523589..62192521a323 100644
--- a/examples/controlnet/train_controlnet_sdxl.py
+++ b/examples/controlnet/train_controlnet_sdxl.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
import argparse
-import contextlib
import functools
import gc
import logging
@@ -22,6 +21,7 @@
import os
import random
import shutil
+from contextlib import nullcontext
from pathlib import Path
import accelerate
@@ -125,11 +125,10 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
)
image_logs = []
- inference_ctx = (
- contextlib.nullcontext()
- if (is_final_validation or torch.backends.mps.is_available())
- else torch.autocast("cuda")
- )
+ if is_final_validation or torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
validation_image = Image.open(validation_image).convert("RGB")
@@ -138,7 +137,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
images = []
for _ in range(args.num_validation_images):
- with inference_ctx:
+ with autocast_ctx:
image = pipeline(
prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
).images[0]
@@ -811,6 +810,10 @@ def main(args):
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index 35d3a59c7231..6858fed8b994 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -676,6 +676,10 @@ def main(args):
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index c40758eebfe9..a18c443e7d4d 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -821,6 +821,10 @@ def main(args):
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 9b43b30e0fe1..0d33a0558989 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -749,6 +749,10 @@ def main(args):
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 1da83ff731ad..f3e347cd6ac9 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -23,6 +23,7 @@
import random
import shutil
import warnings
+from contextlib import nullcontext
from pathlib import Path
import numpy as np
@@ -207,18 +208,12 @@ def log_validation(
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
- enable_autocast = True
- if torch.backends.mps.is_available() or (
- accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
- ):
- enable_autocast = False
- if "playground" in args.pretrained_model_name_or_path:
- enable_autocast = False
+ if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
- with torch.autocast(
- accelerator.device.type,
- enabled=enable_autocast,
- ):
+ with autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
for tracker in accelerator.trackers:
@@ -992,6 +987,10 @@ def main(args):
kwargs_handlers=[kwargs],
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py
index b40f22df59df..f1125a2919f0 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py
@@ -21,6 +21,7 @@
import math
import os
import shutil
+from contextlib import nullcontext
from pathlib import Path
import accelerate
@@ -52,6 +53,9 @@
from diffusers.utils.torch_utils import is_compiled_module
+if is_wandb_available():
+ import wandb
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.28.0.dev0")
@@ -63,6 +67,48 @@
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ generator,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ original_image = download_image(args.val_image_url)
+ edited_images = []
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ for _ in range(args.num_validation_images):
+ edited_images.append(
+ pipeline(
+ args.validation_prompt,
+ image=original_image,
+ num_inference_steps=20,
+ image_guidance_scale=1.5,
+ guidance_scale=7,
+ generator=generator,
+ ).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "wandb":
+ wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+ for edited_image in edited_images:
+ wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
+ tracker.log({"validation": wandb_table})
+
+
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
parser.add_argument(
@@ -404,12 +450,11 @@ def main():
project_config=accelerator_project_config,
)
- generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
- if args.report_to == "wandb":
- if not is_wandb_available():
- raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
- import wandb
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
@@ -512,7 +557,8 @@ def save_model_hook(models, weights, output_dir):
model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again
- weights.pop()
+ if weights:
+ weights.pop()
def load_model_hook(models, input_dir):
if args.use_ema:
@@ -918,11 +964,6 @@ def collate_fn(examples):
and (args.validation_prompt is not None)
and (epoch % args.validation_epochs == 0)
):
- logger.info(
- f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
- f" {args.validation_prompt}."
- )
- # create pipeline
if args.use_ema:
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
ema_unet.store(unet.parameters())
@@ -937,35 +978,14 @@ def collate_fn(examples):
variant=args.variant,
torch_dtype=weight_dtype,
)
- pipeline = pipeline.to(accelerator.device)
- pipeline.set_progress_bar_config(disable=True)
-
- # run inference
- original_image = download_image(args.val_image_url)
- edited_images = []
- with torch.autocast(
- str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
- ):
- for _ in range(args.num_validation_images):
- edited_images.append(
- pipeline(
- args.validation_prompt,
- image=original_image,
- num_inference_steps=20,
- image_guidance_scale=1.5,
- guidance_scale=7,
- generator=generator,
- ).images[0]
- )
-
- for tracker in accelerator.trackers:
- if tracker.name == "wandb":
- wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
- for edited_image in edited_images:
- wandb_table.add_data(
- wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
- )
- tracker.log({"validation": wandb_table})
+
+ log_validation(
+ pipeline,
+ args,
+ accelerator,
+ generator,
+ )
+
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
@@ -976,7 +996,6 @@ def collate_fn(examples):
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
- unet = unwrap_model(unet)
if args.use_ema:
ema_unet.copy_to(unet.parameters())
@@ -984,7 +1003,7 @@ def collate_fn(examples):
args.pretrained_model_name_or_path,
text_encoder=unwrap_model(text_encoder),
vae=unwrap_model(vae),
- unet=unet,
+ unet=unwrap_model(unet),
revision=args.revision,
variant=args.variant,
)
@@ -998,31 +1017,13 @@ def collate_fn(examples):
ignore_patterns=["step_*", "epoch_*"],
)
- if args.validation_prompt is not None:
- edited_images = []
- pipeline = pipeline.to(accelerator.device)
- with torch.autocast(str(accelerator.device).replace(":0", "")):
- for _ in range(args.num_validation_images):
- edited_images.append(
- pipeline(
- args.validation_prompt,
- image=original_image,
- num_inference_steps=20,
- image_guidance_scale=1.5,
- guidance_scale=7,
- generator=generator,
- ).images[0]
- )
-
- for tracker in accelerator.trackers:
- if tracker.name == "wandb":
- wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
- for edited_image in edited_images:
- wandb_table.add_data(
- wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
- )
- tracker.log({"test": wandb_table})
-
+ if (args.val_image_url is not None) and (args.validation_prompt is not None):
+ log_validation(
+ pipeline,
+ args,
+ accelerator,
+ generator,
+ )
accelerator.end_training()
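Two smaller fixes ride along with the log_validation refactor above: the final validation now also requires args.val_image_url to be set (the helper downloads that image), and the save hook only pops from weights when the list is non-empty, since Accelerate can hand the hook an empty list in some configurations and an unconditional pop then raised IndexError. A stripped-down sketch of the guarded hook, with the surrounding script assumed to register it via accelerator.register_save_state_pre_hook:

import os

def save_model_hook(models, weights, output_dir):
    # `models` and `weights` are parallel lists supplied by Accelerate.
    for model in models:
        model.save_pretrained(os.path.join(output_dir, "unet"))
        # make sure to pop the weight so the corresponding model is not saved again,
        # but only if Accelerate actually populated the list
        if weights:
            weights.pop()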
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index aff279963a99..1c0cdf04b2d2 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -20,6 +20,7 @@
import os
import shutil
import warnings
+from contextlib import nullcontext
from pathlib import Path
from urllib.parse import urlparse
@@ -70,9 +71,7 @@
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
-def log_validation(
- pipeline, args, accelerator, generator, global_step, is_final_validation=False, enable_autocast=True
-):
+def log_validation(pipeline, args, accelerator, generator, global_step, is_final_validation=False):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
@@ -91,7 +90,12 @@ def log_validation(
else Image.open(image_url_or_path).convert("RGB")
)(args.val_image_url_or_path)
- with torch.autocast(accelerator.device.type, enabled=enable_autocast):
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
edited_images = []
# Run inference
for val_img_idx in range(args.num_validation_images):
@@ -507,6 +511,10 @@ def main():
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
# Make one log on every process with the configuration for debugging.
@@ -983,13 +991,6 @@ def collate_fn(examples):
if accelerator.is_main_process:
accelerator.init_trackers("instruct-pix2pix-xl", config=vars(args))
- # Some configurations require autocast to be disabled.
- enable_autocast = True
- if torch.backends.mps.is_available() or (
- accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
- ):
- enable_autocast = False
-
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -1202,7 +1203,6 @@ def collate_fn(examples):
generator,
global_step,
is_final_validation=False,
- enable_autocast=enable_autocast,
)
if args.use_ema:
@@ -1252,7 +1252,6 @@ def collate_fn(examples):
generator,
global_step,
is_final_validation=True,
- enable_autocast=enable_autocast,
)
accelerator.end_training()
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
index a2a13398124a..78f9b7f18b87 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
@@ -458,6 +458,10 @@ def main():
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
index d6fce4937413..eb8ae8cca060 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
@@ -343,6 +343,11 @@ def main():
log_with=args.report_to,
project_config=accelerator_project_config,
)
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
index b19af1f3e341..e169cf92beb9 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
@@ -356,6 +356,11 @@ def main():
log_with=args.report_to,
project_config=accelerator_project_config,
)
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
index bbc0960e0f48..bd95aed2939c 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
@@ -459,6 +459,10 @@ def main():
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/research_projects/controlnet/train_controlnet_webdataset.py b/examples/research_projects/controlnet/train_controlnet_webdataset.py
index 2b397d27d6a2..615eb834ac24 100644
--- a/examples/research_projects/controlnet/train_controlnet_webdataset.py
+++ b/examples/research_projects/controlnet/train_controlnet_webdataset.py
@@ -916,6 +916,10 @@ def main(args):
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/research_projects/controlnetxs/controlnetxs.py b/examples/research_projects/controlnetxs/controlnetxs.py
deleted file mode 100644
index 14ad1d8a3af9..000000000000
--- a/examples/research_projects/controlnetxs/controlnetxs.py
+++ /dev/null
@@ -1,1014 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import math
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import torch
-import torch.utils.checkpoint
-from torch import nn
-from torch.nn import functional as F
-from torch.nn.modules.normalization import GroupNorm
-
-from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.models.attention_processor import USE_PEFT_BACKEND, AttentionProcessor
-from diffusers.models.autoencoders import AutoencoderKL
-from diffusers.models.lora import LoRACompatibleConv
-from diffusers.models.modeling_utils import ModelMixin
-from diffusers.models.unets.unet_2d_blocks import (
- CrossAttnDownBlock2D,
- CrossAttnUpBlock2D,
- DownBlock2D,
- Downsample2D,
- ResnetBlock2D,
- Transformer2DModel,
- UpBlock2D,
- Upsample2D,
-)
-from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from diffusers.utils import BaseOutput, logging
-
-
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-
-@dataclass
-class ControlNetXSOutput(BaseOutput):
- """
- The output of [`ControlNetXSModel`].
-
- Args:
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
- The output of the `ControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base model
- output, but is already the final output.
- """
-
- sample: torch.FloatTensor = None
-
-
-# copied from diffusers.models.controlnet.ControlNetConditioningEmbedding
-class ControlNetConditioningEmbedding(nn.Module):
- """
- Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
- [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
- training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
- convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
- (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
- model) to encode image-space conditions ... into feature maps ..."
- """
-
- def __init__(
- self,
- conditioning_embedding_channels: int,
- conditioning_channels: int = 3,
- block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
- ):
- super().__init__()
-
- self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
-
- self.blocks = nn.ModuleList([])
-
- for i in range(len(block_out_channels) - 1):
- channel_in = block_out_channels[i]
- channel_out = block_out_channels[i + 1]
- self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
- self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
-
- self.conv_out = zero_module(
- nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
- )
-
- def forward(self, conditioning):
- embedding = self.conv_in(conditioning)
- embedding = F.silu(embedding)
-
- for block in self.blocks:
- embedding = block(embedding)
- embedding = F.silu(embedding)
-
- embedding = self.conv_out(embedding)
-
- return embedding
-
-
-class ControlNetXSModel(ModelMixin, ConfigMixin):
- r"""
- A ControlNet-XS model
-
- This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for it's generic
- methods implemented for all models (such as downloading or saving).
-
- Most of parameters for this model are passed into the [`UNet2DConditionModel`] it creates. Check the documentation
- of [`UNet2DConditionModel`] for them.
-
- Parameters:
- conditioning_channels (`int`, defaults to 3):
- Number of channels of conditioning input (e.g. an image)
- controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
- The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
- conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
- The tuple of output channel for each block in the `controlnet_cond_embedding` layer.
- time_embedding_input_dim (`int`, defaults to 320):
- Dimension of input into time embedding. Needs to be same as in the base model.
- time_embedding_dim (`int`, defaults to 1280):
- Dimension of output from time embedding. Needs to be same as in the base model.
- learn_embedding (`bool`, defaults to `False`):
- Whether to use time embedding of the control model. If yes, the time embedding is a linear interpolation of
- the time embeddings of the control and base model with interpolation parameter `time_embedding_mix**3`.
- time_embedding_mix (`float`, defaults to 1.0):
- Linear interpolation parameter used if `learn_embedding` is `True`. A value of 1.0 means only the
- control model's time embedding will be used. A value of 0.0 means only the base model's time embedding will be used.
- base_model_channel_sizes (`Dict[str, List[Tuple[int]]]`):
- Channel sizes of each subblock of base model. Use `gather_subblock_sizes` on your base model to compute it.
- """
-
- @classmethod
- def init_original(cls, base_model: UNet2DConditionModel, is_sdxl=True):
- """
- Create a ControlNetXS model with the same parameters as in the original paper (https://github.com/vislearn/ControlNet-XS).
-
- Parameters:
- base_model (`UNet2DConditionModel`):
- Base UNet model. Needs to be either StableDiffusion or StableDiffusion-XL.
- is_sdxl (`bool`, defaults to `True`):
- Whether passed `base_model` is a StableDiffusion-XL model.
- """
-
- def get_dim_attn_heads(base_model: UNet2DConditionModel, size_ratio: float, num_attn_heads: int):
- """
- Currently, diffusers can only set the dimension of attention heads (see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why).
- The original ControlNet-XS model, however, define the number of attention heads.
- That's why compute the dimensions needed to get the correct number of attention heads.
- """
- block_out_channels = [int(size_ratio * c) for c in base_model.config.block_out_channels]
- dim_attn_heads = [math.ceil(c / num_attn_heads) for c in block_out_channels]
- return dim_attn_heads
-
- if is_sdxl:
- return ControlNetXSModel.from_unet(
- base_model,
- time_embedding_mix=0.95,
- learn_embedding=True,
- size_ratio=0.1,
- conditioning_embedding_out_channels=(16, 32, 96, 256),
- num_attention_heads=get_dim_attn_heads(base_model, 0.1, 64),
- )
- else:
- return ControlNetXSModel.from_unet(
- base_model,
- time_embedding_mix=1.0,
- learn_embedding=True,
- size_ratio=0.0125,
- conditioning_embedding_out_channels=(16, 32, 96, 256),
- num_attention_heads=get_dim_attn_heads(base_model, 0.0125, 8),
- )
-
- @classmethod
- def _gather_subblock_sizes(cls, unet: UNet2DConditionModel, base_or_control: str):
- """To create correctly sized connections between base and control model, we need to know
- the input and output channels of each subblock.
-
- Parameters:
- unet (`UNet2DConditionModel`):
- Unet of which the subblock channels sizes are to be gathered.
- base_or_control (`str`):
- Needs to be either "base" or "control". If "base", decoder is also considered.
- """
- if base_or_control not in ["base", "control"]:
- raise ValueError("`base_or_control` needs to be either `base` or `control`")
-
- channel_sizes = {"down": [], "mid": [], "up": []}
-
- # input convolution
- channel_sizes["down"].append((unet.conv_in.in_channels, unet.conv_in.out_channels))
-
- # encoder blocks
- for module in unet.down_blocks:
- if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
- for r in module.resnets:
- channel_sizes["down"].append((r.in_channels, r.out_channels))
- if module.downsamplers:
- channel_sizes["down"].append(
- (module.downsamplers[0].channels, module.downsamplers[0].out_channels)
- )
- else:
- raise ValueError(f"Encountered unknown module of type {type(module)} while creating ControlNet-XS.")
-
- # middle block
- channel_sizes["mid"].append((unet.mid_block.resnets[0].in_channels, unet.mid_block.resnets[0].out_channels))
-
- # decoder blocks
- if base_or_control == "base":
- for module in unet.up_blocks:
- if isinstance(module, (CrossAttnUpBlock2D, UpBlock2D)):
- for r in module.resnets:
- channel_sizes["up"].append((r.in_channels, r.out_channels))
- else:
- raise ValueError(
- f"Encountered unknown module of type {type(module)} while creating ControlNet-XS."
- )
-
- return channel_sizes
-
- @register_to_config
- def __init__(
- self,
- conditioning_channels: int = 3,
- conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
- controlnet_conditioning_channel_order: str = "rgb",
- time_embedding_input_dim: int = 320,
- time_embedding_dim: int = 1280,
- time_embedding_mix: float = 1.0,
- learn_embedding: bool = False,
- base_model_channel_sizes: Dict[str, List[Tuple[int]]] = {
- "down": [
- (4, 320),
- (320, 320),
- (320, 320),
- (320, 320),
- (320, 640),
- (640, 640),
- (640, 640),
- (640, 1280),
- (1280, 1280),
- ],
- "mid": [(1280, 1280)],
- "up": [
- (2560, 1280),
- (2560, 1280),
- (1920, 1280),
- (1920, 640),
- (1280, 640),
- (960, 640),
- (960, 320),
- (640, 320),
- (640, 320),
- ],
- },
- sample_size: Optional[int] = None,
- down_block_types: Tuple[str] = (
- "CrossAttnDownBlock2D",
- "CrossAttnDownBlock2D",
- "CrossAttnDownBlock2D",
- "DownBlock2D",
- ),
- up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
- norm_num_groups: Optional[int] = 32,
- cross_attention_dim: Union[int, Tuple[int]] = 1280,
- transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
- num_attention_heads: Optional[Union[int, Tuple[int]]] = 8,
- upcast_attention: bool = False,
- ):
- super().__init__()
-
- # 1 - Create control unet
- self.control_model = UNet2DConditionModel(
- sample_size=sample_size,
- down_block_types=down_block_types,
- up_block_types=up_block_types,
- block_out_channels=block_out_channels,
- norm_num_groups=norm_num_groups,
- cross_attention_dim=cross_attention_dim,
- transformer_layers_per_block=transformer_layers_per_block,
- attention_head_dim=num_attention_heads,
- use_linear_projection=True,
- upcast_attention=upcast_attention,
- time_embedding_dim=time_embedding_dim,
- )
-
- # 2 - Do model surgery on control model
- # 2.1 - Allow to use the same time information as the base model
- adjust_time_dims(self.control_model, time_embedding_input_dim, time_embedding_dim)
-
- # 2.2 - Allow for information infusion from base model
-
- # We concat the output of each base encoder subblocks to the input of the next control encoder subblock
- # (We ignore the 1st element, as it represents the `conv_in`.)
- extra_input_channels = [input_channels for input_channels, _ in base_model_channel_sizes["down"][1:]]
- it_extra_input_channels = iter(extra_input_channels)
-
- for b, block in enumerate(self.control_model.down_blocks):
- for r in range(len(block.resnets)):
- increase_block_input_in_encoder_resnet(
- self.control_model, block_no=b, resnet_idx=r, by=next(it_extra_input_channels)
- )
-
- if block.downsamplers:
- increase_block_input_in_encoder_downsampler(
- self.control_model, block_no=b, by=next(it_extra_input_channels)
- )
-
- increase_block_input_in_mid_resnet(self.control_model, by=extra_input_channels[-1])
-
- # 2.3 - Make group norms work with modified channel sizes
- adjust_group_norms(self.control_model)
-
- # 3 - Gather Channel Sizes
- self.ch_inout_ctrl = ControlNetXSModel._gather_subblock_sizes(self.control_model, base_or_control="control")
- self.ch_inout_base = base_model_channel_sizes
-
- # 4 - Build connections between base and control model
- self.down_zero_convs_out = nn.ModuleList([])
- self.down_zero_convs_in = nn.ModuleList([])
- self.middle_block_out = nn.ModuleList([])
- self.middle_block_in = nn.ModuleList([])
- self.up_zero_convs_out = nn.ModuleList([])
- self.up_zero_convs_in = nn.ModuleList([])
-
- for ch_io_base in self.ch_inout_base["down"]:
- self.down_zero_convs_in.append(self._make_zero_conv(in_channels=ch_io_base[1], out_channels=ch_io_base[1]))
- for i in range(len(self.ch_inout_ctrl["down"])):
- self.down_zero_convs_out.append(
- self._make_zero_conv(self.ch_inout_ctrl["down"][i][1], self.ch_inout_base["down"][i][1])
- )
-
- self.middle_block_out = self._make_zero_conv(
- self.ch_inout_ctrl["mid"][-1][1], self.ch_inout_base["mid"][-1][1]
- )
-
- self.up_zero_convs_out.append(
- self._make_zero_conv(self.ch_inout_ctrl["down"][-1][1], self.ch_inout_base["mid"][-1][1])
- )
- for i in range(1, len(self.ch_inout_ctrl["down"])):
- self.up_zero_convs_out.append(
- self._make_zero_conv(self.ch_inout_ctrl["down"][-(i + 1)][1], self.ch_inout_base["up"][i - 1][1])
- )
-
- # 5 - Create conditioning hint embedding
- self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
- conditioning_embedding_channels=block_out_channels[0],
- block_out_channels=conditioning_embedding_out_channels,
- conditioning_channels=conditioning_channels,
- )
-
- # In the mininal implementation setting, we only need the control model up to the mid block
- del self.control_model.up_blocks
- del self.control_model.conv_norm_out
- del self.control_model.conv_out
-
- @classmethod
- def from_unet(
- cls,
- unet: UNet2DConditionModel,
- conditioning_channels: int = 3,
- conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
- controlnet_conditioning_channel_order: str = "rgb",
- learn_embedding: bool = False,
- time_embedding_mix: float = 1.0,
- block_out_channels: Optional[Tuple[int]] = None,
- size_ratio: Optional[float] = None,
- num_attention_heads: Optional[Union[int, Tuple[int]]] = 8,
- norm_num_groups: Optional[int] = None,
- ):
- r"""
- Instantiate a [`ControlNetXSModel`] from [`UNet2DConditionModel`].
-
- Parameters:
- unet (`UNet2DConditionModel`):
- The UNet model we want to control. The dimensions of the ControlNetXSModel will be adapted to it.
- conditioning_channels (`int`, defaults to 3):
- Number of channels of conditioning input (e.g. an image)
- conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
- The tuple of output channel for each block in the `controlnet_cond_embedding` layer.
- controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
- The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
- learn_embedding (`bool`, defaults to `False`):
- Wether to use time embedding of the control model. If yes, the time embedding is a linear interpolation
- of the time embeddings of the control and base model with interpolation parameter
- `time_embedding_mix**3`.
- time_embedding_mix (`float`, defaults to 1.0):
- Linear interpolation parameter used if `learn_embedding` is `True`.
- block_out_channels (`Tuple[int]`, *optional*):
- Down blocks output channels in control model. Either this or `size_ratio` must be given.
- size_ratio (float, *optional*):
- When given, block_out_channels is set to a relative fraction of the base model's block_out_channels.
- Either this or `block_out_channels` must be given.
- num_attention_heads (`Union[int, Tuple[int]]`, *optional*):
- The dimension of the attention heads. The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
- norm_num_groups (int, *optional*, defaults to `None`):
- The number of groups to use for the normalization of the control unet. If `None`,
- `int(unet.config.norm_num_groups * size_ratio)` is taken.
- """
-
- # Check input
- fixed_size = block_out_channels is not None
- relative_size = size_ratio is not None
- if not (fixed_size ^ relative_size):
- raise ValueError(
- "Pass exactly one of `block_out_channels` (for absolute sizing) or `control_model_ratio` (for relative sizing)."
- )
-
- # Create model
- if block_out_channels is None:
- block_out_channels = [int(size_ratio * c) for c in unet.config.block_out_channels]
-
- # Check that attention heads and group norms match channel sizes
- # - attention heads
- def attn_heads_match_channel_sizes(attn_heads, channel_sizes):
- if isinstance(attn_heads, (tuple, list)):
- return all(c % a == 0 for a, c in zip(attn_heads, channel_sizes))
- else:
- return all(c % attn_heads == 0 for c in channel_sizes)
-
- num_attention_heads = num_attention_heads or unet.config.attention_head_dim
- if not attn_heads_match_channel_sizes(num_attention_heads, block_out_channels):
- raise ValueError(
- f"The dimension of attention heads ({num_attention_heads}) must divide `block_out_channels` ({block_out_channels}). If you didn't set `num_attention_heads` the default settings don't match your model. Set `num_attention_heads` manually."
- )
-
- # - group norms
- def group_norms_match_channel_sizes(num_groups, channel_sizes):
- return all(c % num_groups == 0 for c in channel_sizes)
-
- if norm_num_groups is None:
- if group_norms_match_channel_sizes(unet.config.norm_num_groups, block_out_channels):
- norm_num_groups = unet.config.norm_num_groups
- else:
- norm_num_groups = min(block_out_channels)
-
- if group_norms_match_channel_sizes(norm_num_groups, block_out_channels):
- print(
- f"`norm_num_groups` was set to `min(block_out_channels)` (={norm_num_groups}) so it divides all block_out_channels` ({block_out_channels}). Set it explicitly to remove this information."
- )
- else:
- raise ValueError(
- f"`block_out_channels` ({block_out_channels}) don't match the base models `norm_num_groups` ({unet.config.norm_num_groups}). Setting `norm_num_groups` to `min(block_out_channels)` ({norm_num_groups}) didn't fix this. Pass `norm_num_groups` explicitly so it divides all block_out_channels."
- )
-
- def get_time_emb_input_dim(unet: UNet2DConditionModel):
- return unet.time_embedding.linear_1.in_features
-
- def get_time_emb_dim(unet: UNet2DConditionModel):
- return unet.time_embedding.linear_2.out_features
-
- # Clone params from base unet if
- # (i) it's required to build SD or SDXL, and
- # (ii) it's not used for the time embedding (as time embedding of control model is never used), and
- # (iii) it's not set further below anyway
- to_keep = [
- "cross_attention_dim",
- "down_block_types",
- "sample_size",
- "transformer_layers_per_block",
- "up_block_types",
- "upcast_attention",
- ]
- kwargs = {k: v for k, v in dict(unet.config).items() if k in to_keep}
- kwargs.update(block_out_channels=block_out_channels)
- kwargs.update(num_attention_heads=num_attention_heads)
- kwargs.update(norm_num_groups=norm_num_groups)
-
- # Add controlnetxs-specific params
- kwargs.update(
- conditioning_channels=conditioning_channels,
- controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
- time_embedding_input_dim=get_time_emb_input_dim(unet),
- time_embedding_dim=get_time_emb_dim(unet),
- time_embedding_mix=time_embedding_mix,
- learn_embedding=learn_embedding,
- base_model_channel_sizes=ControlNetXSModel._gather_subblock_sizes(unet, base_or_control="base"),
- conditioning_embedding_out_channels=conditioning_embedding_out_channels,
- )
-
- return cls(**kwargs)
-
- @property
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- return self.control_model.attn_processors
-
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- self.control_model.set_attn_processor(processor)
-
- def set_default_attn_processor(self):
- """
- Disables custom attention processors and sets the default attention implementation.
- """
- self.control_model.set_default_attn_processor()
-
- def set_attention_slice(self, slice_size):
- r"""
- Enable sliced attention computation.
-
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
-
- Args:
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
- must be a multiple of `slice_size`.
- """
- self.control_model.set_attention_slice(slice_size)
-
- def _set_gradient_checkpointing(self, module, value=False):
- if isinstance(module, (UNet2DConditionModel)):
- if value:
- module.enable_gradient_checkpointing()
- else:
- module.disable_gradient_checkpointing()
-
- def forward(
- self,
- base_model: UNet2DConditionModel,
- sample: torch.FloatTensor,
- timestep: Union[torch.Tensor, float, int],
- encoder_hidden_states: torch.Tensor,
- controlnet_cond: torch.Tensor,
- conditioning_scale: float = 1.0,
- class_labels: Optional[torch.Tensor] = None,
- timestep_cond: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
- return_dict: bool = True,
- ) -> Union[ControlNetXSOutput, Tuple]:
- """
- The [`ControlNetXSModel`] forward method.
-
- Args:
- base_model (`UNet2DConditionModel`):
- The base unet model we want to control.
- sample (`torch.FloatTensor`):
- The noisy input tensor.
- timestep (`Union[torch.Tensor, float, int]`):
- The number of timesteps to denoise an input.
- encoder_hidden_states (`torch.Tensor`):
- The encoder hidden states.
- controlnet_cond (`torch.FloatTensor`):
- The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
- conditioning_scale (`float`, defaults to `1.0`):
- How much the control model affects the base model outputs.
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
- timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
- Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
- timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
- embeddings.
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
- negative values to the attention scores corresponding to "discard" tokens.
- added_cond_kwargs (`dict`):
- Additional conditions for the Stable Diffusion XL UNet.
- cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
- A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
- return_dict (`bool`, defaults to `True`):
- Whether or not to return a [`~models.controlnetxs.ControlNetXSOutput`] instead of a plain tuple.
-
- Returns:
- [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`:
- If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a
- tuple is returned where the first element is the sample tensor.
- """
- # check channel order
- channel_order = self.config.controlnet_conditioning_channel_order
-
- if channel_order == "rgb":
- # in rgb order by default
- ...
- elif channel_order == "bgr":
- controlnet_cond = torch.flip(controlnet_cond, dims=[1])
- else:
- raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
-
- # scale control strength
- n_connections = len(self.down_zero_convs_out) + 1 + len(self.up_zero_convs_out)
- scale_list = torch.full((n_connections,), conditioning_scale)
-
- # prepare attention_mask
- if attention_mask is not None:
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
- attention_mask = attention_mask.unsqueeze(1)
-
- # 1. time
- timesteps = timestep
- if not torch.is_tensor(timesteps):
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
- # This would be a good case for the `match` statement (Python 3.10+)
- is_mps = sample.device.type == "mps"
- if isinstance(timestep, float):
- dtype = torch.float32 if is_mps else torch.float64
- else:
- dtype = torch.int32 if is_mps else torch.int64
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
- elif len(timesteps.shape) == 0:
- timesteps = timesteps[None].to(sample.device)
-
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timesteps = timesteps.expand(sample.shape[0])
-
- t_emb = base_model.time_proj(timesteps)
-
- # timesteps does not contain any weights and will always return f32 tensors
- # but time_embedding might actually be running in fp16. so we need to cast here.
- # there might be better ways to encapsulate this.
- t_emb = t_emb.to(dtype=sample.dtype)
-
- if self.config.learn_embedding:
- ctrl_temb = self.control_model.time_embedding(t_emb, timestep_cond)
- base_temb = base_model.time_embedding(t_emb, timestep_cond)
- interpolation_param = self.config.time_embedding_mix**0.3
-
- temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
- else:
- temb = base_model.time_embedding(t_emb)
-
- # added time & text embeddings
- aug_emb = None
-
- if base_model.class_embedding is not None:
- if class_labels is None:
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
-
- if base_model.config.class_embed_type == "timestep":
- class_labels = base_model.time_proj(class_labels)
-
- class_emb = base_model.class_embedding(class_labels).to(dtype=self.dtype)
- temb = temb + class_emb
-
- if base_model.config.addition_embed_type is not None:
- if base_model.config.addition_embed_type == "text":
- aug_emb = base_model.add_embedding(encoder_hidden_states)
- elif base_model.config.addition_embed_type == "text_image":
- raise NotImplementedError()
- elif base_model.config.addition_embed_type == "text_time":
- # SDXL - style
- if "text_embeds" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
- )
- text_embeds = added_cond_kwargs.get("text_embeds")
- if "time_ids" not in added_cond_kwargs:
- raise ValueError(
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
- )
- time_ids = added_cond_kwargs.get("time_ids")
- time_embeds = base_model.add_time_proj(time_ids.flatten())
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
- add_embeds = add_embeds.to(temb.dtype)
- aug_emb = base_model.add_embedding(add_embeds)
- elif base_model.config.addition_embed_type == "image":
- raise NotImplementedError()
- elif base_model.config.addition_embed_type == "image_hint":
- raise NotImplementedError()
-
- temb = temb + aug_emb if aug_emb is not None else temb
-
- # text embeddings
- cemb = encoder_hidden_states
-
- # Preparation
- guided_hint = self.controlnet_cond_embedding(controlnet_cond)
-
- h_ctrl = h_base = sample
- hs_base, hs_ctrl = [], []
- it_down_convs_in, it_down_convs_out, it_dec_convs_in, it_up_convs_out = map(
- iter, (self.down_zero_convs_in, self.down_zero_convs_out, self.up_zero_convs_in, self.up_zero_convs_out)
- )
- scales = iter(scale_list)
-
- base_down_subblocks = to_sub_blocks(base_model.down_blocks)
- ctrl_down_subblocks = to_sub_blocks(self.control_model.down_blocks)
- base_mid_subblocks = to_sub_blocks([base_model.mid_block])
- ctrl_mid_subblocks = to_sub_blocks([self.control_model.mid_block])
- base_up_subblocks = to_sub_blocks(base_model.up_blocks)
-
- # Cross Control
- # 0 - conv in
- h_base = base_model.conv_in(h_base)
- h_ctrl = self.control_model.conv_in(h_ctrl)
- if guided_hint is not None:
- h_ctrl += guided_hint
- h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base
-
- hs_base.append(h_base)
- hs_ctrl.append(h_ctrl)
-
- # 1 - down
- for m_base, m_ctrl in zip(base_down_subblocks, ctrl_down_subblocks):
- h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl
- h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock
- h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock
- h_base = h_base + next(it_down_convs_out)(h_ctrl) * next(scales) # D - add ctrl -> base
- hs_base.append(h_base)
- hs_ctrl.append(h_ctrl)
-
- # 2 - mid
- h_ctrl = torch.cat([h_ctrl, next(it_down_convs_in)(h_base)], dim=1) # A - concat base -> ctrl
- for m_base, m_ctrl in zip(base_mid_subblocks, ctrl_mid_subblocks):
- h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs) # B - apply base subblock
- h_ctrl = m_ctrl(h_ctrl, temb, cemb, attention_mask, cross_attention_kwargs) # C - apply ctrl subblock
- h_base = h_base + self.middle_block_out(h_ctrl) * next(scales) # D - add ctrl -> base
-
- # 3 - up
- for i, m_base in enumerate(base_up_subblocks):
- h_base = h_base + next(it_up_convs_out)(hs_ctrl.pop()) * next(scales) # add info from ctrl encoder
- h_base = torch.cat([h_base, hs_base.pop()], dim=1) # concat info from base encoder+ctrl encoder
- h_base = m_base(h_base, temb, cemb, attention_mask, cross_attention_kwargs)
-
- h_base = base_model.conv_norm_out(h_base)
- h_base = base_model.conv_act(h_base)
- h_base = base_model.conv_out(h_base)
-
- if not return_dict:
- return h_base
-
- return ControlNetXSOutput(sample=h_base)
-
- def _make_zero_conv(self, in_channels, out_channels=None):
- # keep a running track of channel sizes
- self.in_channels = in_channels
- self.out_channels = out_channels or in_channels
-
- return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))
-
- @torch.no_grad()
- def _check_if_vae_compatible(self, vae: AutoencoderKL):
- condition_downscale_factor = 2 ** (len(self.config.conditioning_embedding_out_channels) - 1)
- vae_downscale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
- compatible = condition_downscale_factor == vae_downscale_factor
- return compatible, condition_downscale_factor, vae_downscale_factor
-
-
-class SubBlock(nn.ModuleList):
- """A SubBlock is the largest piece of either base or control model, that is executed independently of the other model respectively.
- Before each subblock, information is concatted from base to control. And after each subblock, information is added from control to base.
- """
-
- def __init__(self, ms, *args, **kwargs):
- if not is_iterable(ms):
- ms = [ms]
- super().__init__(ms, *args, **kwargs)
-
- def forward(
- self,
- x: torch.Tensor,
- temb: torch.Tensor,
- cemb: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- ):
- """Iterate through children and pass correct information to each."""
- for m in self:
- if isinstance(m, ResnetBlock2D):
- x = m(x, temb)
- elif isinstance(m, Transformer2DModel):
- x = m(x, cemb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs).sample
- elif isinstance(m, Downsample2D):
- x = m(x)
- elif isinstance(m, Upsample2D):
- x = m(x)
- else:
- raise ValueError(
- f"Type of m is {type(m)} but should be `ResnetBlock2D`, `Transformer2DModel`, `Downsample2D` or `Upsample2D`"
- )
-
- return x
-
-
-def adjust_time_dims(unet: UNet2DConditionModel, in_dim: int, out_dim: int):
- unet.time_embedding.linear_1 = nn.Linear(in_dim, out_dim)
-
-
-def increase_block_input_in_encoder_resnet(unet: UNet2DConditionModel, block_no, resnet_idx, by):
- """Increase channels sizes to allow for additional concatted information from base model"""
- r = unet.down_blocks[block_no].resnets[resnet_idx]
- old_norm1, old_conv1 = r.norm1, r.conv1
- # norm
- norm_args = "num_groups num_channels eps affine".split(" ")
- for a in norm_args:
- assert hasattr(old_norm1, a)
- norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
- norm_kwargs["num_channels"] += by # surgery done here
- # conv1
- conv1_args = [
- "in_channels",
- "out_channels",
- "kernel_size",
- "stride",
- "padding",
- "dilation",
- "groups",
- "bias",
- "padding_mode",
- ]
- if not USE_PEFT_BACKEND:
- conv1_args.append("lora_layer")
-
- for a in conv1_args:
- assert hasattr(old_conv1, a)
-
- conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
- conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
- conv1_kwargs["in_channels"] += by # surgery done here
- # conv_shortcut
- # as we changed the input size of the block, the input and output sizes are likely different,
- # therefore we need a conv_shortcut (simply adding won't work)
- conv_shortcut_args_kwargs = {
- "in_channels": conv1_kwargs["in_channels"],
- "out_channels": conv1_kwargs["out_channels"],
- # default arguments from resnet.__init__
- "kernel_size": 1,
- "stride": 1,
- "padding": 0,
- "bias": True,
- }
- # swap old with new modules
- unet.down_blocks[block_no].resnets[resnet_idx].norm1 = GroupNorm(**norm_kwargs)
- unet.down_blocks[block_no].resnets[resnet_idx].conv1 = (
- nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
- )
- unet.down_blocks[block_no].resnets[resnet_idx].conv_shortcut = (
- nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
- )
- unet.down_blocks[block_no].resnets[resnet_idx].in_channels += by # surgery done here
-
-
-def increase_block_input_in_encoder_downsampler(unet: UNet2DConditionModel, block_no, by):
- """Increase channels sizes to allow for additional concatted information from base model"""
- old_down = unet.down_blocks[block_no].downsamplers[0].conv
-
- args = [
- "in_channels",
- "out_channels",
- "kernel_size",
- "stride",
- "padding",
- "dilation",
- "groups",
- "bias",
- "padding_mode",
- ]
- if not USE_PEFT_BACKEND:
- args.append("lora_layer")
-
- for a in args:
- assert hasattr(old_down, a)
- kwargs = {a: getattr(old_down, a) for a in args}
- kwargs["bias"] = "bias" in kwargs # as param, bias is a boolean, but as attr, it's a tensor.
- kwargs["in_channels"] += by # surgery done here
- # swap old with new modules
- unet.down_blocks[block_no].downsamplers[0].conv = (
- nn.Conv2d(**kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**kwargs)
- )
- unet.down_blocks[block_no].downsamplers[0].channels += by # surgery done here
-
-
-def increase_block_input_in_mid_resnet(unet: UNet2DConditionModel, by):
- """Increase channels sizes to allow for additional concatted information from base model"""
- m = unet.mid_block.resnets[0]
- old_norm1, old_conv1 = m.norm1, m.conv1
- # norm
- norm_args = "num_groups num_channels eps affine".split(" ")
- for a in norm_args:
- assert hasattr(old_norm1, a)
- norm_kwargs = {a: getattr(old_norm1, a) for a in norm_args}
- norm_kwargs["num_channels"] += by # surgery done here
- conv1_args = [
- "in_channels",
- "out_channels",
- "kernel_size",
- "stride",
- "padding",
- "dilation",
- "groups",
- "bias",
- "padding_mode",
- ]
- if not USE_PEFT_BACKEND:
- conv1_args.append("lora_layer")
-
- conv1_kwargs = {a: getattr(old_conv1, a) for a in conv1_args}
- conv1_kwargs["bias"] = "bias" in conv1_kwargs # as param, bias is a boolean, but as attr, it's a tensor.
- conv1_kwargs["in_channels"] += by # surgery done here
- # conv_shortcut
- # as we changed the input size of the block, the input and output sizes are likely different,
- # therefore we need a conv_shortcut (simply adding won't work)
- conv_shortcut_args_kwargs = {
- "in_channels": conv1_kwargs["in_channels"],
- "out_channels": conv1_kwargs["out_channels"],
- # default arguments from resnet.__init__
- "kernel_size": 1,
- "stride": 1,
- "padding": 0,
- "bias": True,
- }
- # swap old with new modules
- unet.mid_block.resnets[0].norm1 = GroupNorm(**norm_kwargs)
- unet.mid_block.resnets[0].conv1 = (
- nn.Conv2d(**conv1_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv1_kwargs)
- )
- unet.mid_block.resnets[0].conv_shortcut = (
- nn.Conv2d(**conv_shortcut_args_kwargs) if USE_PEFT_BACKEND else LoRACompatibleConv(**conv_shortcut_args_kwargs)
- )
- unet.mid_block.resnets[0].in_channels += by # surgery done here
-
-
-def adjust_group_norms(unet: UNet2DConditionModel, max_num_group: int = 32):
- def find_denominator(number, start):
- if start >= number:
- return number
- while start != 0:
- residual = number % start
- if residual == 0:
- return start
- start -= 1
-
- for block in [*unet.down_blocks, unet.mid_block]:
- # resnets
- for r in block.resnets:
- if r.norm1.num_groups < max_num_group:
- r.norm1.num_groups = find_denominator(r.norm1.num_channels, start=max_num_group)
-
- if r.norm2.num_groups < max_num_group:
- r.norm2.num_groups = find_denominator(r.norm2.num_channels, start=max_num_group)
-
- # transformers
- if hasattr(block, "attentions"):
- for a in block.attentions:
- if a.norm.num_groups < max_num_group:
- a.norm.num_groups = find_denominator(a.norm.num_channels, start=max_num_group)
-
-
-def is_iterable(o):
- if isinstance(o, str):
- return False
- try:
- iter(o)
- return True
- except TypeError:
- return False
-
-
-def to_sub_blocks(blocks):
- if not is_iterable(blocks):
- blocks = [blocks]
-
- sub_blocks = []
-
- for b in blocks:
- if hasattr(b, "resnets"):
- if hasattr(b, "attentions") and b.attentions is not None:
- for r, a in zip(b.resnets, b.attentions):
- sub_blocks.append([r, a])
-
- num_resnets = len(b.resnets)
- num_attns = len(b.attentions)
-
- if num_resnets > num_attns:
- # we can have more resnets than attentions, so add each extra resnet as a separate subblock
- for i in range(num_attns, num_resnets):
- sub_blocks.append([b.resnets[i]])
- else:
- for r in b.resnets:
- sub_blocks.append([r])
-
- # upsamplers are part of the same subblock
- if hasattr(b, "upsamplers") and b.upsamplers is not None:
- for u in b.upsamplers:
- sub_blocks[-1].extend([u])
-
- # downsamplers are own subblock
- if hasattr(b, "downsamplers") and b.downsamplers is not None:
- for d in b.downsamplers:
- sub_blocks.append([d])
-
- return list(map(SubBlock, sub_blocks))
-
-
-def zero_module(module):
- for p in module.parameters():
- nn.init.zeros_(p)
- return module
diff --git a/examples/research_projects/controlnetxs/infer_sd_controlnetxs.py b/examples/research_projects/controlnetxs/infer_sd_controlnetxs.py
deleted file mode 100644
index 722b282a3251..000000000000
--- a/examples/research_projects/controlnetxs/infer_sd_controlnetxs.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# !pip install opencv-python transformers accelerate
-import argparse
-
-import cv2
-import numpy as np
-import torch
-from controlnetxs import ControlNetXSModel
-from PIL import Image
-from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
-
-from diffusers.utils import load_image
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument(
- "--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
-)
-parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
-parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
-parser.add_argument(
- "--image_path",
- type=str,
- default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
-)
-parser.add_argument("--num_inference_steps", type=int, default=50)
-
-args = parser.parse_args()
-
-prompt = args.prompt
-negative_prompt = args.negative_prompt
-# download an image
-image = load_image(args.image_path)
-
-# initialize the models and pipeline
-controlnet_conditioning_scale = args.controlnet_conditioning_scale
-controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16)
-pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-1", controlnet=controlnet, torch_dtype=torch.float16
-)
-pipe.enable_model_cpu_offload()
-
-# get canny image
-image = np.array(image)
-image = cv2.Canny(image, 100, 200)
-image = image[:, :, None]
-image = np.concatenate([image, image, image], axis=2)
-canny_image = Image.fromarray(image)
-
-num_inference_steps = args.num_inference_steps
-
-# generate image
-image = pipe(
- prompt,
- negative_prompt=negative_prompt,
- controlnet_conditioning_scale=controlnet_conditioning_scale,
- image=canny_image,
- num_inference_steps=num_inference_steps,
-).images[0]
-image.save("cnxs_sd.canny.png")
diff --git a/examples/research_projects/controlnetxs/infer_sdxl_controlnetxs.py b/examples/research_projects/controlnetxs/infer_sdxl_controlnetxs.py
deleted file mode 100644
index e5b8cfd88223..000000000000
--- a/examples/research_projects/controlnetxs/infer_sdxl_controlnetxs.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# !pip install opencv-python transformers accelerate
-import argparse
-
-import cv2
-import numpy as np
-import torch
-from controlnetxs import ControlNetXSModel
-from PIL import Image
-from pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
-
-from diffusers.utils import load_image
-
-
-parser = argparse.ArgumentParser()
-parser.add_argument(
- "--prompt", type=str, default="aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
-)
-parser.add_argument("--negative_prompt", type=str, default="low quality, bad quality, sketches")
-parser.add_argument("--controlnet_conditioning_scale", type=float, default=0.7)
-parser.add_argument(
- "--image_path",
- type=str,
- default="https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png",
-)
-parser.add_argument("--num_inference_steps", type=int, default=50)
-
-args = parser.parse_args()
-
-prompt = args.prompt
-negative_prompt = args.negative_prompt
-# download an image
-image = load_image(args.image_path)
-# initialize the models and pipeline
-controlnet_conditioning_scale = args.controlnet_conditioning_scale
-controlnet = ControlNetXSModel.from_pretrained("UmerHA/ConrolNetXS-SDXL-canny", torch_dtype=torch.float16)
-pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
-)
-pipe.enable_model_cpu_offload()
-
-# get canny image
-image = np.array(image)
-image = cv2.Canny(image, 100, 200)
-image = image[:, :, None]
-image = np.concatenate([image, image, image], axis=2)
-canny_image = Image.fromarray(image)
-
-num_inference_steps = args.num_inference_steps
-
-# generate image
-image = pipe(
- prompt,
- negative_prompt=negative_prompt,
- controlnet_conditioning_scale=controlnet_conditioning_scale,
- image=canny_image,
- num_inference_steps=num_inference_steps,
-).images[0]
-image.save("cnxs_sdxl.canny.png")
diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
index 4bb6b894476a..3cec037e2544 100644
--- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
+++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
@@ -484,6 +484,10 @@ def main(args):
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
index 24d51658e3b0..0297a06f5b2c 100644
--- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
+++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
@@ -526,6 +526,10 @@ def main(args):
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
index 2fcddad8b63f..cdc096190f08 100644
--- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
+++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
@@ -516,6 +516,10 @@ def main(args):
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
index fc4e458f90ea..cd1ef265d23e 100644
--- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
+++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
@@ -623,6 +623,10 @@ def main(args):
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
index 436706c8512d..997d448fa281 100644
--- a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
+++ b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
@@ -21,6 +21,7 @@
import math
import os
import shutil
+from contextlib import nullcontext
from pathlib import Path
import accelerate
@@ -410,6 +411,10 @@ def main():
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
if args.report_to == "wandb":
@@ -967,9 +972,12 @@ def collate_fn(examples):
# run inference
original_image = download_image(args.val_image_url)
edited_images = []
- with torch.autocast(
- str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
- ):
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
diff --git a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
index 3cfd72821490..ea4a0d255b68 100644
--- a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
+++ b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
@@ -378,6 +378,10 @@ def main():
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py
index 462c3bbd44cf..cf00bf270057 100644
--- a/examples/research_projects/lora/train_text_to_image_lora.py
+++ b/examples/research_projects/lora/train_text_to_image_lora.py
@@ -411,6 +411,11 @@ def main():
log_with=args.report_to,
project_config=accelerator_project_config,
)
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
index 53b70abd0115..0f507b26d6a8 100644
--- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
+++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
@@ -698,6 +698,10 @@ def main(args):
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
index 5fab1b6e9cbc..57ad77477b0d 100644
--- a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
+++ b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
@@ -566,6 +566,10 @@ def main():
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
index 2045ef4197c1..ee61f033d34d 100644
--- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
+++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
@@ -439,6 +439,10 @@ def main():
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
index 5d774d591d9a..e10564fa59ef 100644
--- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
+++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
@@ -581,6 +581,10 @@ def main():
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
index cba2c5117bfe..9a00f7cc4a9a 100644
--- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
+++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
@@ -295,6 +295,10 @@ def main(args):
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
if args.logger == "tensorboard":
if not is_tensorboard_available():
raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
diff --git a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
index 8c454e91b2db..dcbc2704b833 100644
--- a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
+++ b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
@@ -789,7 +789,12 @@ def prepare_image(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/research_projects/rdm/pipeline_rdm.py b/examples/research_projects/rdm/pipeline_rdm.py
index dd97bf71b9db..e0c4847c7e39 100644
--- a/examples/research_projects/rdm/pipeline_rdm.py
+++ b/examples/research_projects/rdm/pipeline_rdm.py
@@ -123,7 +123,12 @@ def _encode_image(self, retrieved_images, batch_size):
return image_embeddings
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/examples/research_projects/scheduled_huber_loss_training/README.md b/examples/research_projects/scheduled_huber_loss_training/README.md
new file mode 100644
index 000000000000..239f94ba1005
--- /dev/null
+++ b/examples/research_projects/scheduled_huber_loss_training/README.md
@@ -0,0 +1,15 @@
+# Scheduled Pseudo-Huber Loss for Diffusers
+
+These are modifications of the Diffusers training scripts to include the possibility of training text2image models with the Scheduled Pseudo-Huber loss introduced in https://arxiv.org/abs/2403.16728 (https://github.com/kabachuha/SPHL-for-stable-diffusion).
+
+## Why might this be useful?
+
+- If you suspect that part of the training dataset might be corrupted and you don't want these outliers to distort the model's output
+
+- If you want to improve the aesthetic quality of pictures by helping the model disentangle concepts and be less influenced by other sorts of pictures.
+
+See https://github.com/huggingface/diffusers/issues/7488 for a detailed description.
+
+## Instructions
+
+Usage is the same as for the corresponding vanilla Diffusers scripts: https://github.com/huggingface/diffusers/tree/main/examples. The scheduled losses are enabled via the new `--loss_type`, `--huber_schedule`, and `--huber_c` arguments.
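+
+## Example
+
+For intuition, here is a minimal sketch of the scheduled Pseudo-Huber loss. The function name and the exact exponential decay are assumptions for illustration only; the actual scripts implement the full set of `--huber_schedule` options (including the default `snr`).
+
+```python
+import math
+
+import torch
+
+
+def scheduled_pseudo_huber_loss(model_pred, target, timesteps, num_train_timesteps, huber_c=0.1, schedule="exponential"):
+    # Choose the per-sample Pseudo-Huber parameter (delta) according to the schedule.
+    if schedule == "constant":
+        delta = torch.full((model_pred.shape[0],), huber_c, dtype=model_pred.dtype, device=model_pred.device)
+    elif schedule == "exponential":
+        # Assumed decay for illustration: interpolate in log-space from 1 (t=0) down to huber_c (t=T).
+        delta = torch.exp(math.log(huber_c) * timesteps.float() / num_train_timesteps).to(model_pred.dtype)
+    else:
+        raise ValueError(f"Unknown huber schedule: {schedule}")
+    delta = delta.view(-1, 1, 1, 1)
+    # Pseudo-Huber behaves like an L2 loss for small residuals and like a (scaled) L1 loss for large ones.
+    loss = 2 * delta * (torch.sqrt((model_pred - target) ** 2 + delta**2) - delta)
+    return loss.mean()
+```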
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
new file mode 100644
index 000000000000..779da7328d97
--- /dev/null
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
@@ -0,0 +1,1518 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import gc
+import importlib
+import itertools
+import logging
+import math
+import os
+import shutil
+import warnings
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, model_info, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.28.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images: list = None,
+ base_model: str = None,
+ train_text_encoder=False,
+ prompt: str = None,
+ repo_folder: str = None,
+ pipeline: DiffusionPipeline = None,
+):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ model_description = f"""
+# DreamBooth - {repo_id}
+
+This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
+You can find some example images below. \n
+{img_str}
+
+DreamBooth for the text encoder was enabled: {train_text_encoder}.
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ prompt=prompt,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = ["text-to-image", "dreambooth", "diffusers-training"]
+ if isinstance(pipeline, StableDiffusionPipeline):
+ tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
+ else:
+ tags.extend(["if", "if-diffusers"])
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ text_encoder,
+ tokenizer,
+ unet,
+ vae,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ prompt_embeds,
+ negative_prompt_embeds,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+
+ pipeline_args = {}
+
+ if vae is not None:
+ pipeline_args["vae"] = vae
+
+ # create pipeline (note: unet and vae are loaded again in float32)
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ unet=unet,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ **pipeline_args,
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ module = importlib.import_module("diffusers")
+ scheduler_class = getattr(module, args.validation_scheduler)
+ pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args)
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.pre_compute_text_embeddings:
+ pipeline_args = {
+ "prompt_embeds": prompt_embeds,
+ "negative_prompt_embeds": negative_prompt_embeds,
+ }
+ else:
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ # run inference
+ generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ images = []
+ if args.validation_images is None:
+ for _ in range(args.num_validation_images):
+ with torch.autocast("cuda"):
+ image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
+ images.append(image)
+ else:
+ for image in args.validation_images:
+ image = Image.open(image)
+ image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ elif model_class == "T5EncoderModel":
+ from transformers import T5EncoderModel
+
+ return T5EncoderModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="dreambooth-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=(
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
+ " for more details"
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=100,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+
+ parser.add_argument(
+ "--offset_noise",
+ action="store_true",
+ default=False,
+ help=(
+ "Fine-tuning against a modified noise"
+ " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
+ ),
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--pre_compute_text_embeddings",
+ action="store_true",
+ help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
+ )
+ parser.add_argument(
+ "--tokenizer_max_length",
+ type=int,
+ default=None,
+ required=False,
+ help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
+ )
+ parser.add_argument(
+ "--text_encoder_use_attention_mask",
+ action="store_true",
+ required=False,
+ help="Whether to use attention mask for the text encoder",
+ )
+ parser.add_argument(
+ "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder"
+ )
+ parser.add_argument(
+ "--validation_images",
+ required=False,
+ default=None,
+ nargs="+",
+ help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
+ )
+ parser.add_argument(
+ "--class_labels_conditioning",
+ required=False,
+ default=None,
+ help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber", "smooth_l1"],
+ help="The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.",
+ )
+ parser.add_argument(
+ "--huber_schedule",
+ type=str,
+ default="snr",
+ choices=["constant", "exponential", "snr"],
+ help="The schedule to use for the huber losses parameter",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.1,
+ help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
+ )
+ parser.add_argument(
+ "--validation_scheduler",
+ type=str,
+ default="DPMSolverMultistepScheduler",
+ choices=["DPMSolverMultistepScheduler", "DDPMScheduler"],
+ help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ if args.train_text_encoder and args.pre_compute_text_embeddings:
+ raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ class_num=None,
+ size=512,
+ center_crop=False,
+ encoder_hidden_states=None,
+ class_prompt_encoder_hidden_states=None,
+ tokenizer_max_length=None,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+ self.encoder_hidden_states = encoder_hidden_states
+ self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
+ self.tokenizer_max_length = tokenizer_max_length
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ instance_image = exif_transpose(instance_image)
+
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+
+ if self.encoder_hidden_states is not None:
+ example["instance_prompt_ids"] = self.encoder_hidden_states
+ else:
+ text_inputs = tokenize_prompt(
+ self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["instance_prompt_ids"] = text_inputs.input_ids
+ example["instance_attention_mask"] = text_inputs.attention_mask
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+
+ if self.class_prompt_encoder_hidden_states is not None:
+ example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
+ else:
+ class_text_inputs = tokenize_prompt(
+ self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["class_prompt_ids"] = class_text_inputs.input_ids
+ example["class_attention_mask"] = class_text_inputs.attention_mask
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ has_attention_mask = "instance_attention_mask" in examples[0]
+
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ if has_attention_mask:
+ attention_mask = [example["instance_attention_mask"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
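+ # The resulting batch is ordered [instance examples, class examples] along dim 0, which is
+ # why the training loop later splits model_pred and target with torch.chunk(..., 2, dim=0).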
+ if with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+
+ if has_attention_mask:
+ attention_mask += [example["class_attention_mask"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.cat(input_ids, dim=0)
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+
+ if has_attention_mask:
+ attention_mask = torch.cat(attention_mask, dim=0)
+ batch["attention_mask"] = attention_mask
+
+ return batch
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def model_has_vae(args):
+ config_file_name = os.path.join("vae", AutoencoderKL.config_name)
+ if os.path.isdir(args.pretrained_model_name_or_path):
+ config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
+ return os.path.isfile(config_file_name)
+ else:
+ files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
+ return any(file.rfilename == config_file_name for file in files_in_repo)
+
+
+def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
+ if tokenizer_max_length is not None:
+ max_length = tokenizer_max_length
+ else:
+ max_length = tokenizer.model_max_length
+
+ text_inputs = tokenizer(
+ prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=max_length,
+ return_tensors="pt",
+ )
+
+ return text_inputs
+
+
+def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
+ text_input_ids = input_ids.to(text_encoder.device)
+
+ if text_encoder_use_attention_mask:
+ attention_mask = attention_mask.to(text_encoder.device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = text_encoder(
+ text_input_ids,
+ attention_mask=attention_mask,
+ return_dict=False,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ return prompt_embeds
+
+
+# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
+def conditional_loss(
+ model_pred: torch.Tensor,
+ target: torch.Tensor,
+ reduction: str = "mean",
+ loss_type: str = "l2",
+ huber_c: float = 0.1,
+):
+ if loss_type == "l2":
+ loss = F.mse_loss(model_pred, target, reduction=reduction)
+ elif loss_type == "huber":
+ loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ elif loss_type == "smooth_l1":
+ loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+ return loss
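+
+# Illustrative usage of conditional_loss (a sketch, not executed by the training flow): with
+# loss_type="huber", the pseudo-Huber penalty 2 * c * (sqrt(err**2 + c**2) - c) behaves like an
+# L2 loss for |err| << c and like an L1 loss for |err| >> c, e.g.
+#   pred, target = torch.zeros(2, 4), torch.ones(2, 4)
+#   conditional_loss(pred, target, reduction="mean", loss_type="huber", huber_c=0.1)
+# The "smooth_l1" branch is the same expression without the leading huber_c factor.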
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+ raise ValueError(
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ safety_checker=None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+
+ if model_has_vae(args):
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+ else:
+ vae = None
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
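+ # Sketch of the resulting layout: save_model_hook writes each model to <checkpoint_dir>/unet or
+ # <checkpoint_dir>/text_encoder, and load_model_hook reads them back from the same subfolders,
+ # so `--resume_from_checkpoint` restores both models.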
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ for model in models:
+ sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder"
+ model.save_pretrained(os.path.join(output_dir, sub_dir))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(text_encoder))):
+ # load transformers style into model
+ load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
+ model.config = load_model.config
+ else:
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if vae is not None:
+ vae.requires_grad_(False)
+
+ if not args.train_text_encoder:
+ text_encoder.requires_grad_(False)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder.gradient_checkpointing_enable()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ "Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training. copy of the weights should still be float32."
+ )
+
+ if unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}")
+
+ if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
+ raise ValueError(
+ f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = (
+ itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
+ )
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.pre_compute_text_embeddings:
+
+ def compute_text_embeddings(prompt):
+ with torch.no_grad():
+ text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
+ prompt_embeds = encode_prompt(
+ text_encoder,
+ text_inputs.input_ids,
+ text_inputs.attention_mask,
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
+ return prompt_embeds
+
+ pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
+ validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
+
+ if args.validation_prompt is not None:
+ validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
+ else:
+ validation_prompt_encoder_hidden_states = None
+
+ if args.class_prompt is not None:
+ pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
+ else:
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ text_encoder = None
+ tokenizer = None
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ else:
+ pre_computed_encoder_hidden_states = None
+ validation_prompt_encoder_hidden_states = None
+ validation_prompt_negative_prompt_embeds = None
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ class_num=args.num_class_images,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ encoder_hidden_states=pre_computed_encoder_hidden_states,
+ class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
+ tokenizer_max_length=args.tokenizer_max_length,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move vae and text_encoder to device and cast to weight_dtype
+ if vae is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ if not args.train_text_encoder and text_encoder is not None:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = vars(copy.deepcopy(args))
+ tracker_config.pop("validation_images")
+ accelerator.init_trackers("dreambooth", config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+
+ if vae is not None:
+ # Convert images to latent space
+ model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ else:
+ model_input = pixel_values
+
+ # Sample noise that we'll add to the model input
+ if args.offset_noise:
+ noise = torch.randn_like(model_input) + 0.1 * torch.randn(
+ model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device
+ )
+ else:
+ noise = torch.randn_like(model_input)
+ bsz, channels, height, width = model_input.shape
+ # Sample a random timestep for each image
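+ # For the scheduled Huber losses a single timestep is drawn and shared across the batch so
+ # that huber_c can be computed as one scalar. Sketch of the schedules implemented below
+ # (T = num_train_timesteps, t = sampled timestep, c = args.huber_c):
+ #   exponential: huber_c = exp(t * ln(c) / T), decaying from 1 at t=0 to c at t=T
+ #   snr:         huber_c = (1 - c) / (1 + sigma_t) ** 2 + c, with sigma_t = sqrt((1 - alphabar_t) / alphabar_t)
+ #   constant:    huber_c = c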
+ if args.loss_type == "huber" or args.loss_type == "smooth_l1":
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device="cpu")
+ timestep = timesteps.item()
+
+ if args.huber_schedule == "exponential":
+ alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+ huber_c = math.exp(-alpha * timestep)
+ elif args.huber_schedule == "snr":
+ alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+ sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+ elif args.huber_schedule == "constant":
+ huber_c = args.huber_c
+ else:
+ raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+ timesteps = timesteps.repeat(bsz).to(model_input.device)
+ elif args.loss_type == "l2":
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ huber_c = 1 # may be anything, as it's not used
+ else:
+ raise NotImplementedError(f"Unknown loss type {args.loss_type}")
+
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
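+ # i.e. for DDPM-style schedulers: noisy_model_input = sqrt(alphabar_t) * model_input + sqrt(1 - alphabar_t) * noise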
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ if args.pre_compute_text_embeddings:
+ encoder_hidden_states = batch["input_ids"]
+ else:
+ encoder_hidden_states = encode_prompt(
+ text_encoder,
+ batch["input_ids"],
+ batch["attention_mask"],
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
+ if unwrap_model(unet).config.in_channels == channels * 2:
+ noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
+
+ if args.class_labels_conditioning == "timesteps":
+ class_labels = timesteps
+ else:
+ class_labels = None
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False
+ )[0]
+
+ if model_pred.shape[1] == 6:
+ model_pred, _ = torch.chunk(model_pred, 2, dim=1)
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+ # Compute prior loss
+ prior_loss = conditional_loss(
+ model_pred_prior.float(),
+ target_prior.float(),
+ reduction="mean",
+ loss_type=args.loss_type,
+ huber_c=huber_c,
+ )
+
+ # Compute instance loss
+ if args.snr_gamma is None:
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+ )
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
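+ # In short: mse_loss_weights = min(SNR_t, snr_gamma) / SNR_t for epsilon prediction, and
+ # min(SNR_t, snr_gamma) / SNR_t + 1 for v-prediction, as computed below.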
+ snr = compute_snr(noise_scheduler, timesteps)
+ base_weight = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective needs to be floored to an SNR weight of one.
+ mse_loss_weights = base_weight + 1
+ else:
+ # Epsilon and sample both use the same loss weights.
+ mse_loss_weights = base_weight
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+ )
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet.parameters(), text_encoder.parameters())
+ if args.train_text_encoder
+ else unet.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ images = []
+
+ if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+ images = log_validation(
+ unwrap_model(text_encoder) if text_encoder is not None else text_encoder,
+ tokenizer,
+ unwrap_model(unet),
+ vae,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ validation_prompt_encoder_hidden_states,
+ validation_prompt_negative_prompt_embeds,
+ )
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ pipeline_args = {}
+
+ if text_encoder is not None:
+ pipeline_args["text_encoder"] = unwrap_model(text_encoder)
+
+ if args.skip_save_text_encoder:
+ pipeline_args["text_encoder"] = None
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ **pipeline_args,
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
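+ # (schedulers configured with a "learned"/"learned_range" variance are switched to "fixed_small"
+ # below, matching the fact that the extra variance channels are discarded during training)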
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline.save_pretrained(args.output_dir)
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ prompt=args.instance_prompt,
+ repo_folder=args.output_dir,
+ pipeline=pipeline,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
new file mode 100644
index 000000000000..73c67dd5ddf7
--- /dev/null
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
@@ -0,0 +1,1504 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import gc
+import logging
+import math
+import os
+import shutil
+import warnings
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DiffusionPipeline,
+ DPMSolverMultistepScheduler,
+ StableDiffusionPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import LoraLoaderMixin
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params
+from diffusers.utils import (
+ check_min_version,
+ convert_state_dict_to_diffusers,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.28.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model=str,
+ train_text_encoder=False,
+ prompt=str,
+ repo_folder=None,
+ pipeline: DiffusionPipeline = None,
+):
+ img_str = ""
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ model_description = f"""
+# LoRA DreamBooth - {repo_id}
+
+These are LoRA adaptation weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images below. \n
+{img_str}
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ prompt=prompt,
+ model_description=model_description,
+ inference=True,
+ )
+ tags = ["text-to-image", "diffusers", "lora", "diffusers-training"]
+ if isinstance(pipeline, StableDiffusionPipeline):
+ tags.extend(["stable-diffusion", "stable-diffusion-diffusers"])
+ else:
+ tags.extend(["if", "if-diffusers"])
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+
+ if args.validation_images is None:
+ images = []
+ for _ in range(args.num_validation_images):
+ with torch.cuda.amp.autocast():
+ image = pipeline(**pipeline_args, generator=generator).images[0]
+ images.append(image)
+ else:
+ images = []
+ for image in args.validation_images:
+ image = Image.open(image)
+ with torch.cuda.amp.autocast():
+ image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ elif model_class == "T5EncoderModel":
+ from transformers import T5EncoderModel
+
+ return T5EncoderModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ required=True,
+ help="A folder containing the training data of instance images.",
+ )
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lora-dreambooth-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--pre_compute_text_embeddings",
+ action="store_true",
+ help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
+ )
+ parser.add_argument(
+ "--tokenizer_max_length",
+ type=int,
+ default=None,
+ required=False,
+ help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
+ )
+ parser.add_argument(
+ "--text_encoder_use_attention_mask",
+ action="store_true",
+ required=False,
+ help="Whether to use attention mask for the text encoder",
+ )
+ parser.add_argument(
+ "--validation_images",
+ required=False,
+ default=None,
+ nargs="+",
+ help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
+ )
+ parser.add_argument(
+ "--class_labels_conditioning",
+ required=False,
+ default=None,
+ help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber", "smooth_l1"],
+ help="The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.",
+ )
+ parser.add_argument(
+ "--huber_schedule",
+ type=str,
+ default="snr",
+ choices=["constant", "exponential", "snr"],
+ help="The schedule to use for the huber losses parameter",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.1,
+ help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
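+ # Note (general LoRA background, not specific to this script): a LoRA update replaces a frozen
+ # weight W with W + B @ A, where A has shape (rank, in_features) and B has shape (out_features, rank),
+ # so a larger --rank trains more parameters per adapted layer.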
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ if args.train_text_encoder and args.pre_compute_text_embeddings:
+ raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images and tokenizes the prompts.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ tokenizer,
+ class_data_root=None,
+ class_prompt=None,
+ class_num=None,
+ size=512,
+ center_crop=False,
+ encoder_hidden_states=None,
+ class_prompt_encoder_hidden_states=None,
+ tokenizer_max_length=None,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+ self.tokenizer = tokenizer
+ self.encoder_hidden_states = encoder_hidden_states
+ self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
+ self.tokenizer_max_length = tokenizer_max_length
+
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
+ self.num_instance_images = len(self.instance_images_path)
+ self.instance_prompt = instance_prompt
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ self.class_prompt = class_prompt
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+ instance_image = exif_transpose(instance_image)
+
+ if not instance_image.mode == "RGB":
+ instance_image = instance_image.convert("RGB")
+ example["instance_images"] = self.image_transforms(instance_image)
+
+ if self.encoder_hidden_states is not None:
+ example["instance_prompt_ids"] = self.encoder_hidden_states
+ else:
+ text_inputs = tokenize_prompt(
+ self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["instance_prompt_ids"] = text_inputs.input_ids
+ example["instance_attention_mask"] = text_inputs.attention_mask
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+
+ if self.class_prompt_encoder_hidden_states is not None:
+ example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
+ else:
+ class_text_inputs = tokenize_prompt(
+ self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
+ )
+ example["class_prompt_ids"] = class_text_inputs.input_ids
+ example["class_attention_mask"] = class_text_inputs.attention_mask
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ has_attention_mask = "instance_attention_mask" in examples[0]
+
+ input_ids = [example["instance_prompt_ids"] for example in examples]
+ pixel_values = [example["instance_images"] for example in examples]
+
+ if has_attention_mask:
+ attention_mask = [example["instance_attention_mask"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ input_ids += [example["class_prompt_ids"] for example in examples]
+ pixel_values += [example["class_images"] for example in examples]
+ if has_attention_mask:
+ attention_mask += [example["class_attention_mask"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ input_ids = torch.cat(input_ids, dim=0)
+
+ batch = {
+ "input_ids": input_ids,
+ "pixel_values": pixel_values,
+ }
+
+ if has_attention_mask:
+ attention_mask = torch.cat(attention_mask, dim=0)
+ batch["attention_mask"] = attention_mask
+
+ return batch
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
+ if tokenizer_max_length is not None:
+ max_length = tokenizer_max_length
+ else:
+ max_length = tokenizer.model_max_length
+
+ text_inputs = tokenizer(
+ prompt,
+ truncation=True,
+ padding="max_length",
+ max_length=max_length,
+ return_tensors="pt",
+ )
+
+ return text_inputs
+
+
+def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
+ text_input_ids = input_ids.to(text_encoder.device)
+
+ if text_encoder_use_attention_mask:
+ attention_mask = attention_mask.to(text_encoder.device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = text_encoder(
+ text_input_ids,
+ attention_mask=attention_mask,
+ return_dict=False,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ return prompt_embeds
+
+
+# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
+def conditional_loss(
+ model_pred: torch.Tensor,
+ target: torch.Tensor,
+ reduction: str = "mean",
+ loss_type: str = "l2",
+ huber_c: float = 0.1,
+):
+ if loss_type == "l2":
+ loss = F.mse_loss(model_pred, target, reduction=reduction)
+ elif loss_type == "huber":
+ loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ elif loss_type == "smooth_l1":
+ loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+ return loss
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
+ # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.
+ if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+ raise ValueError(
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ safety_checker=None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizer
+ if args.tokenizer_name:
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
+ elif args.pretrained_model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder class
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder = text_encoder_cls.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ try:
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+ except OSError:
+ # IF does not have a VAE so let's just set it to None
+ # We don't have to error out here
+ vae = None
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ if vae is not None:
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+    # as these weights are only used for inference; keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ if vae is not None:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder.gradient_checkpointing_enable()
+
+ # now we will add new LoRA weights to the attention layers
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
+ )
+ unet.add_adapter(unet_lora_config)
+
+ # The text encoder comes from 🤗 transformers, we will also attach adapters to it.
+ if args.train_text_encoder:
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder.add_adapter(text_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+            # There are only two options here: either just the unet attention layers,
+            # or both the unet and text encoder attention layers.
+ unet_lora_layers_to_save = None
+ text_encoder_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+ elif isinstance(model, type(unwrap_model(text_encoder))):
+ text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+                # make sure to pop the weight so that the corresponding model is not saved again
+ weights.pop()
+
+ LoraLoaderMixin.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+ text_encoder_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder))):
+ text_encoder_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+
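+        # Strip the "unet." prefix and convert the diffusers-format LoRA state dict to PEFT format before loading it into the model.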
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ if args.train_text_encoder:
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [unet_]
+ if args.train_text_encoder:
+ models.append(text_encoder_)
+
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [unet]
+ if args.train_text_encoder:
+ models.append(text_encoder)
+
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+ if args.train_text_encoder:
+ params_to_optimize = params_to_optimize + list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.pre_compute_text_embeddings:
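+        # Pre-compute the prompt embeddings once so that the text encoder and tokenizer can be freed below to save memory.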
+
+ def compute_text_embeddings(prompt):
+ with torch.no_grad():
+ text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
+ prompt_embeds = encode_prompt(
+ text_encoder,
+ text_inputs.input_ids,
+ text_inputs.attention_mask,
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
+ return prompt_embeds
+
+ pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
+ validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
+
+ if args.validation_prompt is not None:
+ validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
+ else:
+ validation_prompt_encoder_hidden_states = None
+
+ if args.class_prompt is not None:
+ pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
+ else:
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ text_encoder = None
+ tokenizer = None
+
+ gc.collect()
+ torch.cuda.empty_cache()
+ else:
+ pre_computed_encoder_hidden_states = None
+ validation_prompt_encoder_hidden_states = None
+ validation_prompt_negative_prompt_embeds = None
+ pre_computed_class_prompt_encoder_hidden_states = None
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_prompt=args.class_prompt,
+ class_num=args.num_class_images,
+ tokenizer=tokenizer,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ encoder_hidden_states=pre_computed_encoder_hidden_states,
+ class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
+ tokenizer_max_length=args.tokenizer_max_length,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = vars(copy.deepcopy(args))
+ tracker_config.pop("validation_images")
+ accelerator.init_trackers("dreambooth-lora", config=tracker_config)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+            # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder.train()
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+
+ if vae is not None:
+ # Convert images to latent space
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ else:
+ model_input = pixel_values
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz, channels, height, width = model_input.shape
+ # Sample a random timestep for each image
+ if args.loss_type == "huber" or args.loss_type == "smooth_l1":
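+                    # For the scheduled Huber losses, sample a single timestep shared by the whole batch so huber_c can be computed from it as a scalar.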
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device="cpu")
+ timestep = timesteps.item()
+
+ if args.huber_schedule == "exponential":
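+                        # Exponential schedule: huber_c decays from 1 at timestep 0 towards args.huber_c at the last timestep.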
+ alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+ huber_c = math.exp(-alpha * timestep)
+ elif args.huber_schedule == "snr":
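+                        # SNR schedule: huber_c interpolates between 1 at low-noise timesteps (small sigma) and args.huber_c at high-noise timesteps.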
+ alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+ sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+ elif args.huber_schedule == "constant":
+ huber_c = args.huber_c
+ else:
+ raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+ timesteps = timesteps.repeat(bsz).to(model_input.device)
+ elif args.loss_type == "l2":
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ huber_c = 1 # may be anything, as it's not used
+ else:
+ raise NotImplementedError(f"Unknown loss type {args.loss_type}")
+
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ if args.pre_compute_text_embeddings:
+ encoder_hidden_states = batch["input_ids"]
+ else:
+ encoder_hidden_states = encode_prompt(
+ text_encoder,
+ batch["input_ids"],
+ batch["attention_mask"],
+ text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
+ )
+
+ if unwrap_model(unet).config.in_channels == channels * 2:
+ noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
+
+ if args.class_labels_conditioning == "timesteps":
+ class_labels = timesteps
+ else:
+ class_labels = None
+
+ # Predict the noise residual
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ encoder_hidden_states,
+ class_labels=class_labels,
+ return_dict=False,
+ )[0]
+
+                # If the model predicts variance, throw away that prediction and train only on the
+                # simplified training objective. This means that all schedulers using the fine-tuned
+                # model must be configured to use one of the fixed variance types.
+ if model_pred.shape[1] == 6:
+ model_pred, _ = torch.chunk(model_pred, 2, dim=1)
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute instance loss
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+ )
+
+ # Compute prior loss
+ prior_loss = conditional_loss(
+ model_pred_prior.float(),
+ target_prior.float(),
+ reduction="mean",
+ loss_type=args.loss_type,
+ huber_c=huber_c,
+ )
+
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+ else:
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+ )
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unwrap_model(unet),
+ text_encoder=None if args.pre_compute_text_embeddings else unwrap_model(text_encoder),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ if args.pre_compute_text_embeddings:
+ pipeline_args = {
+ "prompt_embeds": validation_prompt_encoder_hidden_states,
+ "negative_prompt_embeds": validation_prompt_negative_prompt_embeds,
+ }
+ else:
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ )
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ unet = unet.to(torch.float32)
+
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+ if args.train_text_encoder:
+ text_encoder = unwrap_model(text_encoder)
+ text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
+ else:
+ text_encoder_state_dict = None
+
+ LoraLoaderMixin.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_state_dict,
+ text_encoder_lora_layers=text_encoder_state_dict,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
+ )
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ prompt=args.instance_prompt,
+ repo_folder=args.output_dir,
+ pipeline=pipeline,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
new file mode 100644
index 000000000000..0004e6d6e87a
--- /dev/null
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
@@ -0,0 +1,2078 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import contextlib
+import gc
+import itertools
+import json
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, hf_hub_download, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from packaging import version
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from safetensors.torch import load_file, save_file
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ DPMSolverMultistepScheduler,
+ EDMEulerScheduler,
+ EulerDiscreteScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import LoraLoaderMixin
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
+from diffusers.utils import (
+ check_min_version,
+ convert_all_state_dict_to_peft,
+ convert_state_dict_to_diffusers,
+ convert_state_dict_to_kohya,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.28.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def determine_scheduler_type(pretrained_model_name_or_path, revision):
+ model_index_filename = "model_index.json"
+ if os.path.isdir(pretrained_model_name_or_path):
+ model_index = os.path.join(pretrained_model_name_or_path, model_index_filename)
+ else:
+ model_index = hf_hub_download(
+ repo_id=pretrained_model_name_or_path, filename=model_index_filename, revision=revision
+ )
+
+ with open(model_index, "r") as f:
+ scheduler_type = json.load(f)["scheduler"][1]
+ return scheduler_type
+
+
+def save_model_card(
+ repo_id: str,
+ use_dora: bool,
+ images=None,
+ base_model: str = None,
+ train_text_encoder=False,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+ vae_path=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# {'SDXL' if 'playground' not in base_model else 'Playground'} LoRA DreamBooth - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} LoRA adaptation weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/).
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Special VAE used for training: {vae_path}.
+
+## Trigger words
+
+You should use {instance_prompt} to trigger the image generation.
+
+## Download model
+
+Weights for this model are available in Safetensors format.
+
+[Download]({repo_id}/tree/main) them in the Files & versions tab.
+
+"""
+ if "playground" in base_model:
+ model_description += """\n
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="openrail++" if "playground" not in base_model else "playground-v2dot5-community",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora" if not use_dora else "dora",
+ "template:sd-lora",
+ ]
+ if "playground" in base_model:
+ tags.extend(["playground", "playground-diffusers"])
+ else:
+ tags.extend(["stable-diffusion-xl", "stable-diffusion-xl-diffusers"])
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+
+ # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+ scheduler_args = {}
+
+ if not args.do_edm_style_training:
+ if "variance_type" in pipeline.scheduler.config:
+ variance_type = pipeline.scheduler.config.variance_type
+
+ if variance_type in ["learned", "learned_range"]:
+ variance_type = "fixed_small"
+
+ scheduler_args["variance_type"] = variance_type
+
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ # Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
+ # way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+ inference_ctx = (
+ contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
+ )
+
+ with inference_ctx:
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--do_edm_style_training",
+ default=False,
+ action="store_true",
+ help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
+ )
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lora-dreambooth-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--output_kohya_format",
+ action="store_true",
+ help="Flag to additionally generate final state dict in the Kohya format so that it becomes compatible with A111, Comfy, Kohya, etc.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--use_dora",
+ action="store_true",
+ default=False,
+ help=(
+ "Wether to train a DoRA as proposed in- DoRA: Weight-Decomposed Low-Rank Adaptation https://arxiv.org/abs/2402.09353. "
+ "Note: to use DoRA you need to install peft from main, `pip install git+https://github.com/huggingface/peft.git`"
+ ),
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber", "smooth_l1"],
+ help="The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.",
+ )
+ parser.add_argument(
+ "--huber_schedule",
+ type=str,
+ default="snr",
+ choices=["constant", "exponential", "snr"],
+ help="The schedule to use for the huber losses parameter",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.1,
+ help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ class_data_root=None,
+ class_num=None,
+ size=1024,
+ repeats=1,
+ center_crop=False,
+ ):
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+            # Get the column names for input/target.
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = dataset["train"][image_column]
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ for img in instance_images:
+ self.instance_images.extend(itertools.repeat(img, repeats))
+
+ # image processing to prepare for using SD-XL micro-conditioning
+ self.original_sizes = []
+ self.crop_top_lefts = []
+ self.pixel_values = []
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ for image in self.instance_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ self.original_sizes.append((image.height, image.width))
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ crop_top_left = (y1, x1)
+ self.crop_top_lefts.append(crop_top_left)
+ image = train_transforms(image)
+ self.pixel_values.append(image)
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.pixel_values[index % self.num_instance_images]
+ original_size = self.original_sizes[index % self.num_instance_images]
+ crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+ example["original_size"] = original_size
+ example["crop_top_left"] = crop_top_left
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+        else:  # no custom prompts were provided, fall back to the shared instance prompt
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+ original_sizes = [example["original_size"] for example in examples]
+ crop_top_lefts = [example["crop_top_left"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+ original_sizes += [example["original_size"] for example in examples]
+ crop_top_lefts += [example["crop_top_left"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {
+ "pixel_values": pixel_values,
+ "prompts": prompts,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+ return batch
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+ prompt_embeds_list = []
+
+ for i, text_encoder in enumerate(text_encoders):
+ if tokenizers is not None:
+ tokenizer = tokenizers[i]
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
+ else:
+ assert text_input_ids_list is not None
+ text_input_ids = text_input_ids_list[i]
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
+ )
+
+        # Only the pooled output of the final text encoder is used; each loop iteration overwrites it,
+        # so the value from the last text encoder is kept.
+ pooled_prompt_embeds = prompt_embeds[0]
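+        # prompt_embeds[-1] holds the hidden states; the penultimate layer output is used as the
+        # per-token embedding, following the SDXL convention.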
+ prompt_embeds = prompt_embeds[-1][-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
+def conditional_loss(
+ model_pred: torch.Tensor,
+ target: torch.Tensor,
+ reduction: str = "mean",
+ loss_type: str = "l2",
+ huber_c: float = 0.1,
+ weighting: Optional[torch.Tensor] = None,
+):
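+    # "huber" implements the Pseudo-Huber loss 2 * c * (sqrt((x - y)**2 + c**2) - c), which behaves like an
+    # L2 loss for residuals much smaller than c and like an L1 loss for much larger residuals; "smooth_l1"
+    # is the same expression without the leading huber_c factor. An optional per-sample `weighting` tensor
+    # is applied to the squared residual before the reduction.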
+ if loss_type == "l2":
+ if weighting is not None:
+ loss = torch.mean(
+ (weighting * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction=reduction)
+
+ elif loss_type == "huber":
+ if weighting is not None:
+ loss = torch.mean(
+ (
+ 2
+ * huber_c
+ * (
+ torch.sqrt(weighting.float() * (model_pred.float() - target.float()) ** 2 + huber_c**2)
+ - huber_c
+ )
+ ).reshape(target.shape[0], -1),
+ 1,
+ )
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ elif loss_type == "smooth_l1":
+ if weighting is not None:
+ loss = torch.mean(
+ (
+ 2
+ * (
+ torch.sqrt(weighting.float() * (model_pred.float() - target.float()) ** 2 + huber_c**2)
+ - huber_c
+ )
+ ).reshape(target.shape[0], -1),
+ 1,
+ )
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+ return loss
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if args.do_edm_style_training and args.snr_gamma is not None:
+ raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ torch_dtype=torch_dtype,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ scheduler_type = determine_scheduler_type(args.pretrained_model_name_or_path, args.revision)
+ if "EDM" in scheduler_type:
+ args.do_edm_style_training = True
+ noise_scheduler = EDMEulerScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ logger.info("Performing EDM-style training!")
+ elif args.do_edm_style_training:
+ noise_scheduler = EulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ logger.info("Performing EDM-style training!")
+ else:
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
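+    # Some checkpoints (such as some Playground-style models) ship a VAE whose latents are not
+    # zero-mean/unit-std; if the VAE config exposes these statistics, they are used below to normalize the latents.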
+ latents_mean = latents_std = None
+ if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
+ latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ unet.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+
+ # The VAE is always in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
+ "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+ text_encoder_two.gradient_checkpointing_enable()
+
+ # now we will add new LoRA weights to the attention layers
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ use_dora=args.use_dora,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ unet.add_adapter(unet_lora_config)
+
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
+ if args.train_text_encoder:
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ use_dora=args.use_dora,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder_one.add_adapter(text_lora_config)
+ text_encoder_two.add_adapter(text_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+            # There are only two options here: either only the unet attention (LoRA) layers are saved,
+            # or both the unet and the text encoder attention layers are saved.
+ unet_lora_layers_to_save = None
+ text_encoder_one_lora_layers_to_save = None
+ text_encoder_two_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
+ text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+ text_encoder_one_ = None
+ text_encoder_two_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
+ text_encoder_two_ = model
+ else:
+                raise ValueError(f"unexpected model when loading: {model.__class__}")
+
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ if args.train_text_encoder:
+ # Do we need to call `scale_lora_layers()` here?
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+ _set_state_dict_into_text_encoder(
+ lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [unet_]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one_, text_encoder_two_])
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [unet]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one, text_encoder_two])
+
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
+
+ if args.train_text_encoder:
+ text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+ text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+
+ # Optimization parameters
+ unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
+ if args.train_text_encoder:
+ # different learning rate for text encoder and unet
+ text_lora_parameters_one_with_lr = {
+ "params": text_lora_parameters_one,
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ text_lora_parameters_two_with_lr = {
+ "params": text_lora_parameters_two,
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ params_to_optimize = [
+ unet_lora_parameters_with_lr,
+ text_lora_parameters_one_with_lr,
+ text_lora_parameters_two_with_lr,
+ ]
+ else:
+ params_to_optimize = [unet_lora_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+ if args.train_text_encoder and args.text_encoder_lr:
+ logger.warning(
+ f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
+ f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+ f"When using prodigy only learning_rate is used as the initial learning rate."
+ )
+ # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+ # --learning_rate
+ params_to_optimize[1]["lr"] = args.learning_rate
+ params_to_optimize[2]["lr"] = args.learning_rate
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ size=args.resolution,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Computes additional embeddings/ids required by the SDXL UNet.
+ # regular text embeddings (when `train_text_encoder` is not True)
+ # pooled text embeddings
+ # time ids
+
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
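+        # SDXL is micro-conditioned on (original_size, crop_top_left, target_size); these are concatenated
+        # into a single "time ids" tensor and passed as an added condition to the unet.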
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ if not args.train_text_encoder:
+ tokenizers = [tokenizer_one, tokenizer_two]
+ text_encoders = [text_encoder_one, text_encoder_two]
+
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ return prompt_embeds, pooled_prompt_embeds
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
+ args.instance_prompt, text_encoders, tokenizers
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ if not args.train_text_encoder:
+ class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
+ args.class_prompt, text_encoders, tokenizers
+ )
+
+ # Clear the memory here
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ del tokenizers, text_encoders
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+
+ if not train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds = instance_prompt_hidden_states
+ unet_add_text_embeds = instance_pooled_prompt_embeds
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
+        # If we are optimizing the text encoder (whether the instance prompt is shared across all images
+        # or custom prompts are provided), we need to tokenize and encode the batch prompts on every training step.
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
+ tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
+ if args.with_prior_preservation:
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
+ class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = (
+ "dreambooth-lora-sd-xl"
+ if "playground" not in args.pretrained_model_name_or_path
+ else "dreambooth-lora-playground"
+ )
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+            # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
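+        # Look up the sigma (noise level) for each sampled timestep and reshape it so that it broadcasts
+        # against the latent tensor (n_dim dimensions in total).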
+ sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder_one.train()
+ text_encoder_two.train()
+
+            # Set requires_grad=True on the top-level embeddings so that gradient checkpointing works.
+ accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+ accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
+
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ prompts = batch["prompts"]
+
+ # encode batch prompts when custom prompts are provided for each image -
+ if train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
+ prompts, text_encoders, tokenizers
+ )
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, prompts)
+ tokens_two = tokenize_prompt(tokenizer_two, prompts)
+
+ # Convert images to latent space
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+
+ if latents_mean is None and latents_std is None:
+ model_input = model_input * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+ else:
+ latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
+ latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
+ model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
+ model_input = model_input.to(dtype=weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ if not args.do_edm_style_training:
+ if args.loss_type == "huber" or args.loss_type == "smooth_l1":
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device="cpu")
+ timestep = timesteps.item()
+
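+                        # huber_c is scheduled over the sampled timestep: "exponential" decays it from 1 at
+                        # t=0 towards args.huber_c at the last timestep, "snr" interpolates between 1 (low
+                        # noise) and args.huber_c (high noise) based on the signal-to-noise ratio, and
+                        # "constant" keeps it fixed at args.huber_c.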
+ if args.huber_schedule == "exponential":
+ alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+ huber_c = math.exp(-alpha * timestep)
+ elif args.huber_schedule == "snr":
+ alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+ sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+ elif args.huber_schedule == "constant":
+ huber_c = args.huber_c
+ else:
+ raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+ timesteps = timesteps.repeat(bsz).to(model_input.device)
+ elif args.loss_type == "l2":
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ huber_c = 1 # may be anything, as it's not used
+ else:
+ raise NotImplementedError(f"Unknown loss type {args.loss_type}")
+ timesteps = timesteps.long()
+ else:
+ if "huber" in args.loss_type or "l1" in args.loss_type:
+ raise NotImplementedError("Huber loss is not implemented for EDM training yet!")
+ # in EDM formulation, the model is conditioned on the pre-conditioned noise levels
+ # instead of discrete timesteps, so here we sample indices to get the noise levels
+ # from `scheduler.timesteps`
+ indices = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,))
+ timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+ # For EDM-style training, we first obtain the sigmas based on the continuous timesteps.
+ # We then precondition the final model inputs based on these sigmas instead of the timesteps.
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ if args.do_edm_style_training:
+ sigmas = get_sigmas(timesteps, len(noisy_model_input.shape), noisy_model_input.dtype)
+ if "EDM" in scheduler_type:
+ inp_noisy_latents = noise_scheduler.precondition_inputs(noisy_model_input, sigmas)
+ else:
+ inp_noisy_latents = noisy_model_input / ((sigmas**2 + 1) ** 0.5)
+
+ # time ids
+ add_time_ids = torch.cat(
+ [
+ compute_time_ids(original_size=s, crops_coords_top_left=c)
+ for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
+ ]
+ )
+
+ # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
+ if not train_dataset.custom_instance_prompts:
+ elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
+ else:
+ elems_to_repeat_text_embeds = 1
+
+ # Predict the noise residual
+ if not args.train_text_encoder:
+ unet_added_conditions = {
+ "time_ids": add_time_ids,
+ "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
+ }
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ model_pred = unet(
+ inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
+ timesteps,
+ prompt_embeds_input,
+ added_cond_kwargs=unet_added_conditions,
+ return_dict=False,
+ )[0]
+ else:
+ unet_added_conditions = {"time_ids": add_time_ids}
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=None,
+ prompt=None,
+ text_input_ids_list=[tokens_one, tokens_two],
+ )
+ unet_added_conditions.update(
+ {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
+ )
+ prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
+ model_pred = unet(
+ inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
+ timesteps,
+ prompt_embeds_input,
+ added_cond_kwargs=unet_added_conditions,
+ return_dict=False,
+ )[0]
+
+ weighting = None
+ if args.do_edm_style_training:
+ # Similar to the input preconditioning, the model predictions are also preconditioned
+ # on noised model inputs (before preconditioning) and the sigmas.
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ if "EDM" in scheduler_type:
+ model_pred = noise_scheduler.precondition_outputs(noisy_model_input, model_pred, sigmas)
+ else:
+ if noise_scheduler.config.prediction_type == "epsilon":
+ model_pred = model_pred * (-sigmas) + noisy_model_input
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ model_pred = model_pred * (-sigmas / (sigmas**2 + 1) ** 0.5) + (
+ noisy_model_input / (sigmas**2 + 1)
+ )
+                    # We are not doing weighting here because it tends to result in numerical problems.
+ # See: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
+ # There might be other alternatives for weighting as well:
+ # https://github.com/huggingface/diffusers/pull/7126#discussion_r1505404686
+ if "EDM" not in scheduler_type:
+ weighting = (sigmas**-2.0).float()
+
+ # Get the target for loss depending on the prediction type
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = model_input if args.do_edm_style_training else noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = (
+ model_input
+ if args.do_edm_style_training
+ else noise_scheduler.get_velocity(model_input, noise, timesteps)
+ )
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = conditional_loss(
+ model_pred_prior,
+ target_prior,
+ reduction="mean",
+ loss_type=args.loss_type,
+ huber_c=huber_c,
+ weighting=weighting,
+ )
+
+ if args.snr_gamma is None:
+ loss = conditional_loss(
+ model_pred,
+ target,
+ reduction="mean",
+ loss_type=args.loss_type,
+ huber_c=huber_c,
+ weighting=weighting,
+ )
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ base_weight = (
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+ )
+
+ if noise_scheduler.config.prediction_type == "v_prediction":
+ # Velocity objective needs to be floored to an SNR weight of one.
+ mse_loss_weights = base_weight + 1
+ else:
+ # Epsilon and sample both use the same loss weights.
+ mse_loss_weights = base_weight
+
+ loss = conditional_loss(
+ model_pred, target, reduction="none", loss_type=args.loss_type, huber_c=huber_c, weighting=None
+ )
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
+ if args.train_text_encoder
+ else unet_lora_parameters
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ if not args.train_text_encoder:
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="text_encoder_2",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ )
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ unet = unet.to(torch.float32)
+ unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+ if args.train_text_encoder:
+ text_encoder_one = unwrap_model(text_encoder_one)
+ text_encoder_lora_layers = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+ )
+ text_encoder_two = unwrap_model(text_encoder_two)
+ text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+ )
+ else:
+ text_encoder_lora_layers = None
+ text_encoder_2_lora_layers = None
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+ )
+ if args.output_kohya_format:
+ lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
+ peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
+ kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
+ save_file(kohya_state_dict, f"{args.output_dir}/pytorch_lora_weights_kohya.safetensors")
+
+ # Final inference
+ # Load previous pipeline
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ use_dora=args.use_dora,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py
new file mode 100644
index 000000000000..0f4cc6c50b5e
--- /dev/null
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py
@@ -0,0 +1,1162 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+from transformers.utils import ContextManagers
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.28.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ args,
+ repo_id: str,
+ images: list = None,
+ repo_folder: str = None,
+):
+ img_str = ""
+    if images is not None and len(images) > 0:
+ image_grid = make_image_grid(images, 1, len(args.validation_prompts))
+ image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
+ img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
+
+ model_description = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}. \n
+{img_str}
+
+## Pipeline usage
+
+You can use the pipeline like so:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
+prompt = "{args.validation_prompts[0]}"
+image = pipeline(prompt).images[0]
+image.save("my_image.png")
+```
+
+## Training info
+
+These are the key hyperparameters used during training:
+
+* Epochs: {args.num_train_epochs}
+* Learning rate: {args.learning_rate}
+* Batch size: {args.train_batch_size}
+* Gradient accumulation steps: {args.gradient_accumulation_steps}
+* Image resolution: {args.resolution}
+* Mixed-precision: {args.mixed_precision}
+
+"""
+ wandb_info = ""
+ if is_wandb_available():
+ wandb_run_url = None
+ if wandb.run is not None:
+ wandb_run_url = wandb.run.url
+
+ if wandb_run_url is not None:
+ wandb_info = f"""
+More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
+"""
+
+ model_description += wandb_info
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=args.pretrained_model_name_or_path,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "diffusers-training"]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
+ logger.info("Running validation... ")
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=accelerator.unwrap_model(vae),
+ text_encoder=accelerator.unwrap_model(text_encoder),
+ tokenizer=tokenizer,
+ unet=accelerator.unwrap_model(unet),
+ safety_checker=None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ images = []
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompts",
+ type=str,
+ default=None,
+ nargs="+",
+ help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--non_ema_revision",
+ type=str,
+ default=None,
+ required=False,
+ help=(
+ "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+ " remote repository specified with --pretrained_model_name_or_path."
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber", "smooth_l1"],
+ help="The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.",
+ )
+ parser.add_argument(
+ "--huber_schedule",
+ type=str,
+ default="snr",
+ choices=["constant", "exponential", "snr"],
+ help="The schedule to use for the Huber loss parameter huber_c.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.1,
+ help="The Huber loss parameter. Only used if one of the Huber-style losses (huber or smooth_l1) is selected with loss_type.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=5,
+ help="Run validation every X epochs.",
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="text2image-fine-tune",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers. For"
+ " more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ # default to using the same revision for the non-ema model if not specified
+ if args.non_ema_revision is None:
+ args.non_ema_revision = args.revision
+
+ return args
+
+
+# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
+def conditional_loss(
+ model_pred: torch.Tensor,
+ target: torch.Tensor,
+ reduction: str = "mean",
+ loss_type: str = "l2",
+ huber_c: float = 0.1,
+):
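+ # A sketch of the implemented loss forms, with x = model_pred - target and c = huber_c:
+ #   l2:        x**2 (standard MSE)
+ #   huber:     2*c*(sqrt(x**2 + c**2) - c)  ~= x**2 for |x| << c and ~= 2*c*|x| for |x| >> c
+ #   smooth_l1: 2*(sqrt(x**2 + c**2) - c)    ~= x**2 / c for |x| << c and ~= 2*|x| for |x| >> c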
+ if loss_type == "l2":
+ loss = F.mse_loss(model_pred, target, reduction=reduction)
+ elif loss_type == "huber":
+ loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ elif loss_type == "smooth_l1":
+ loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+ return loss
+
+
+def main():
+ args = parse_args()
+
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if args.non_ema_revision is not None:
+ deprecate(
+ "non_ema_revision!=None",
+ "0.15.0",
+ message=(
+ "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+ " use `--variant=non_ema` instead."
+ ),
+ )
+ logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+
+ def deepspeed_zero_init_disabled_context_manager():
+ """
+ Returns either a list containing a context manager that disables zero.Init, or an empty list.
+ """
+ deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+ if deepspeed_plugin is None:
+ return []
+
+ return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+ # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
+ # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
+ # will try to assign the same optimizer with the same weights to all models during
+ # `deepspeed.initialize`, which of course doesn't work.
+ #
+ # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
+ # frozen models from being partitioned during `zero.Init` which gets called during
+ # `from_pretrained`. As a result, CLIPTextModel and AutoencoderKL will not benefit from parameter sharding
+ # across multiple GPUs; only UNet2DConditionModel will be ZeRO-sharded.
+ with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
+ )
+
+ # Freeze vae and text_encoder and set unet to trainable
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.train()
+
+ # Create EMA for the unet.
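+ # The EMA copy tracks an exponential moving average of the UNet weights; it is saved to the
+ # `unet_ema` subfolder in checkpoints and copied back into the UNet before the final save.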
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop the weight so that the corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ unet.parameters(),
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
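+ # Images are resized to `resolution`, center- or random-cropped, optionally flipped, and
+ # normalized from [0, 1] to [-1, 1], the input range the VAE expects.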
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+ # For mixed precision training we cast the non-trainable weights (vae and text_encoder) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ args.mixed_precision = accelerator.mixed_precision
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+ args.mixed_precision = accelerator.mixed_precision
+
+ # Move text_encoder and vae to the GPU and cast to weight_dtype
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+ tracker_config.pop("validation_prompts")
+ accelerator.init_trackers(args.tracker_project_name, tracker_config)
+
+ # Function for unwrapping if the model was compiled with `torch.compile`.
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+ )
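+ # Input perturbation: the noisy latents below are built from this jittered copy of the noise,
+ # while the loss target is still computed from the original `noise`.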
+ if args.input_perturbation:
+ new_noise = noise + args.input_perturbation * torch.randn_like(noise)
+ bsz = latents.shape[0]
+
+ # Sample a random timestep for each image
+ if args.loss_type == "huber" or args.loss_type == "smooth_l1":
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device="cpu")
+ timestep = timesteps.item()
+
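+ # Schedule huber_c over the sampled timestep (one timestep shared by the whole batch):
+ #   exponential: decays from 1 at t=0 down to args.huber_c at t=num_train_timesteps
+ #   snr:         ~1 at high SNR (low noise) and approaches args.huber_c as the SNR drops
+ #   constant:    args.huber_c for every timestep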
+ if args.huber_schedule == "exponential":
+ alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+ huber_c = math.exp(-alpha * timestep)
+ elif args.huber_schedule == "snr":
+ alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+ sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+ elif args.huber_schedule == "constant":
+ huber_c = args.huber_c
+ else:
+ raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+ timesteps = timesteps.repeat(bsz).to(latents.device)
+ elif args.loss_type == "l2":
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+ )
+ huber_c = 1 # may be anything, as it's not used
+ else:
+ raise NotImplementedError(f"Unknown loss type {args.loss_type}")
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ if args.input_perturbation:
+ noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
+ else:
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
+
+ if args.snr_gamma is None:
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+ )
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
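+ # With gamma = snr_gamma, the resulting per-sample weight is min(snr, gamma) / snr for
+ # epsilon-prediction and min(snr, gamma) / (snr + 1) for v-prediction.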
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+ )
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+ log_validation(
+ vae,
+ text_encoder,
+ tokenizer,
+ unet,
+ args,
+ accelerator,
+ weight_dtype,
+ global_step,
+ )
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
+ # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ text_encoder=text_encoder,
+ vae=vae,
+ unet=unet,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.save_pretrained(args.output_dir)
+
+ # Run a final round of inference.
+ images = []
+ if args.validation_prompts is not None:
+ logger.info("Running inference for collecting generated images...")
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.torch_dtype = weight_dtype
+ pipeline.set_progress_bar_config(disable=True)
+
+ if args.enable_xformers_memory_efficient_attention:
+ pipeline.enable_xformers_memory_efficient_attention()
+
+ if args.seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+ for i in range(len(args.validation_prompts)):
+ with torch.autocast("cuda"):
+ image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+ images.append(image)
+
+ if args.push_to_hub:
+ save_model_card(args, repo_id, images, repo_folder=args.output_dir)
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py
new file mode 100644
index 000000000000..f22519b02e2b
--- /dev/null
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py
@@ -0,0 +1,1051 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
+from torchvision import transforms
+from tqdm.auto import tqdm
+from transformers import CLIPTextModel, CLIPTokenizer
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import cast_training_params, compute_snr
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.28.0.dev0")
+
+logger = get_logger(__name__, log_level="INFO")
+
+
+def save_model_card(
+ repo_id: str,
+ images: list = None,
+ base_model: str = None,
+ dataset_name: str = None,
+ repo_folder: str = None,
+):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ model_description = f"""
+# LoRA text2image fine-tuning - {repo_id}
+These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below. \n
+{img_str}
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion",
+ "stable-diffusion-diffusers",
+ "text-to-image",
+ "diffusers",
+ "diffusers-training",
+ "lora",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt", type=str, default=None, help="A prompt that is used to generate sample images during validation."
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+ "Run fine-tuning validation every X epochs. The validation process consists of generating"
+ " `args.num_validation_images` images from the prompt `args.validation_prompt`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="Whether to randomly flip images horizontally.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type to use for training. Choose between 'epsilon' and 'v_prediction', or leave it as `None`. If `None`, the scheduler's default prediction type (`noise_scheduler.config.prediction_type`) is used.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber", "smooth_l1"],
+ help="The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.",
+ )
+ parser.add_argument(
+ "--huber_schedule",
+ type=str,
+ default="snr",
+ choices=["constant", "exponential", "snr"],
+ help="The schedule to use for the Huber loss parameter huber_c.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.1,
+ help="The Huber loss parameter. Only used if one of the Huber-style losses (huber or smooth_l1) is selected with loss_type.",
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already
+def conditional_loss(
+ model_pred: torch.Tensor,
+ target: torch.Tensor,
+ reduction: str = "mean",
+ loss_type: str = "l2",
+ huber_c: float = 0.1,
+):
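+ # A sketch of the implemented loss forms, with x = model_pred - target and c = huber_c:
+ #   l2:        x**2 (standard MSE)
+ #   huber:     2*c*(sqrt(x**2 + c**2) - c)  ~= x**2 for |x| << c and ~= 2*c*|x| for |x| >> c
+ #   smooth_l1: 2*(sqrt(x**2 + c**2) - c)    ~= x**2 / c for |x| << c and ~= 2*|x| for |x| >> c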
+ if loss_type == "l2":
+ loss = F.mse_loss(model_pred, target, reduction=reduction)
+ elif loss_type == "huber":
+ loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ elif loss_type == "smooth_l1":
+ loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+ return loss
+
+
+def main():
+ args = parse_args()
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+ # Load scheduler, tokenizer and models.
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+ )
+ text_encoder = CLIPTextModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+ )
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+ # freeze parameters of models to save more memory
+ unet.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Freeze the unet parameters before adding adapters
+ for param in unet.parameters():
+ param.requires_grad_(False)
+
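+ # LoRA injects trainable rank-`args.rank` update matrices into the UNet attention projections
+ # (to_q, to_k, to_v, to_out.0); `lora_alpha` scales their contribution relative to the rank.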
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # Add adapter and make sure the trainable params are in float32.
+ unet.add_adapter(unet_lora_config)
+ if args.mixed_precision == "fp16":
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(unet, dtype=torch.float32)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Initialize the optimizer
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
+ )
+
+ optimizer_cls = bnb.optim.AdamW8bit
+ else:
+ optimizer_cls = torch.optim.AdamW
+
+ optimizer = optimizer_cls(
+ lora_layers,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ data_dir=args.train_data_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+ f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ inputs = tokenizer(
+ captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+ )
+ return inputs.input_ids
+
+ # Preprocessing the datasets.
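+ # Images are resized to `resolution`, center- or random-cropped, optionally flipped, and
+ # normalized from [0, 1] to [-1, 1], the input range the VAE expects.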
+ train_transforms = transforms.Compose(
+ [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ examples["pixel_values"] = [train_transforms(image) for image in images]
+ examples["input_ids"] = tokenize_captions(examples)
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ input_ids = torch.stack([example["input_ids"] for example in examples])
+ return {"pixel_values": pixel_values, "input_ids": input_ids}
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+ )
+
+ bsz = latents.shape[0]
+ # Sample a random timestep for each image
+ if args.loss_type == "huber" or args.loss_type == "smooth_l1":
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device="cpu")
+ timestep = timesteps.item()
+
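+ # Schedule huber_c over the sampled timestep (one timestep shared by the whole batch):
+ #   exponential: decays from 1 at t=0 down to args.huber_c at t=num_train_timesteps
+ #   snr:         ~1 at high SNR (low noise) and approaches args.huber_c as the SNR drops
+ #   constant:    args.huber_c for every timestep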
+ if args.huber_schedule == "exponential":
+ alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+ huber_c = math.exp(-alpha * timestep)
+ elif args.huber_schedule == "snr":
+ alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+ sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+ elif args.huber_schedule == "constant":
+ huber_c = args.huber_c
+ else:
+ raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+ timesteps = timesteps.repeat(bsz).to(latents.device)
+ elif args.loss_type == "l2":
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+ )
+ huber_c = 1 # may be anything, as it's not used
+ else:
+ raise NotImplementedError(f"Unknown loss type {args.loss_type}")
+
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Get the text embedding for conditioning
+ encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ # Predict the noise residual and compute loss
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
+
+ if args.snr_gamma is None:
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+ )
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
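+                    # min(SNR, gamma) is divided by SNR (epsilon-prediction) or SNR + 1 (v-prediction) below,
+                    # giving the per-timestep Min-SNR-gamma loss weight.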
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+ )
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = lora_layers
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if global_step % args.checkpointing_steps == 0:
+ if accelerator.is_main_process:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+
+ unwrapped_unet = unwrap_model(unet)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(unwrapped_unet)
+ )
+
+ StableDiffusionPipeline.save_lora_weights(
+ save_directory=save_path,
+ unet_lora_layers=unet_lora_state_dict,
+ safe_serialization=True,
+ )
+
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device)
+ if args.seed is not None:
+ generator = generator.manual_seed(args.seed)
+ images = []
+ with torch.cuda.amp.autocast():
+ for _ in range(args.num_validation_images):
+ images.append(
+ pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unet.to(torch.float32)
+
+ unwrapped_unet = unwrap_model(unet)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
+ StableDiffusionPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_state_dict,
+ safe_serialization=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ # Final inference
+ # Load previous pipeline
+ if args.validation_prompt is not None:
+ pipeline = DiffusionPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device)
+ if args.seed is not None:
+ generator = generator.manual_seed(args.seed)
+ images = []
+ with torch.cuda.amp.autocast():
+ for _ in range(args.num_validation_images):
+ images.append(
+ pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if len(images) != 0:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
new file mode 100644
index 000000000000..e5ff9d39e8ba
--- /dev/null
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
@@ -0,0 +1,1384 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion XL for text2image with support for LoRA."""
+
+import argparse
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.loaders import LoraLoaderMixin
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
+from diffusers.utils import (
+ check_min_version,
+ convert_state_dict_to_diffusers,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.28.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images: list = None,
+ base_model: str = None,
+ dataset_name: str = None,
+ train_text_encoder: bool = False,
+ repo_folder: str = None,
+ vae_path: str = None,
+):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ model_description = f"""
+# LoRA text2image fine-tuning - {repo_id}
+
+These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below.\n
+{img_str}
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+
+Special VAE used for training: {vae_path}.
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion-xl",
+ "stable-diffusion-xl-diffusers",
+ "text-to-image",
+ "diffusers",
+ "diffusers-training",
+ "lora",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+ help=(
+            "Run fine-tuning validation every X epochs. The validation process consists of generating"
+            " `args.num_validation_images` images with the prompt `args.validation_prompt`."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd-model-finetuned-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+        help="Whether to randomly flip images horizontally.",
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+        help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left as `None`, the default prediction type of the scheduler (`noise_scheduler.config.prediction_type`) is used.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber", "smooth_l1"],
+ help="The type of loss to use and whether it's timestep-scheduled. See Issue #7488 for more info.",
+ )
+ parser.add_argument(
+ "--huber_schedule",
+ type=str,
+ default="snr",
+ choices=["constant", "exponential", "snr"],
+        help="The schedule to use for the Huber loss parameter `huber_c`.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.1,
+        help="The Huber loss parameter. Only used if one of the Huber loss modes (`huber` or `smooth_l1`) is selected with `--loss_type`.",
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--debug_loss",
+ action="store_true",
+        help="Debug loss for each image, if filenames are available in the dataset",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ return args
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def tokenize_prompt(tokenizer, prompt):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+ prompt_embeds_list = []
+
+ for i, text_encoder in enumerate(text_encoders):
+ if tokenizers is not None:
+ tokenizer = tokenizers[i]
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
+ else:
+ assert text_input_ids_list is not None
+ text_input_ids = text_input_ids_list[i]
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
+ )
+
+        # We are only interested in the pooled output of the final text encoder.
+ pooled_prompt_embeds = prompt_embeds[0]
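+        # hidden_states[-2]: the penultimate hidden state is used for conditioning, matching the SDXL pipeline.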
+ prompt_embeds = prompt_embeds[-1][-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+
+# NOTE: when using the scheduled version, huber_c must already depend on the sampled timestep
+def conditional_loss(
+ model_pred: torch.Tensor,
+ target: torch.Tensor,
+ reduction: str = "mean",
+ loss_type: str = "l2",
+ huber_c: float = 0.1,
+):
+ if loss_type == "l2":
+ loss = F.mse_loss(model_pred, target, reduction=reduction)
+    elif loss_type == "huber" or loss_type == "huber_scheduled":
+        loss = huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+        if reduction == "mean":
+            loss = torch.mean(loss)
+        elif reduction == "sum":
+            loss = torch.sum(loss)
+    elif loss_type == "smooth_l1":
+        # Smooth L1 here is the pseudo-Huber loss above scaled by 1 / huber_c.
+        loss = torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c
+        if reduction == "mean":
+            loss = torch.mean(loss)
+        elif reduction == "sum":
+            loss = torch.sum(loss)
+    else:
+        raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+ return loss
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ unet.requires_grad_(False)
+
+    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision,
+    # as these weights are only used for inference; keeping them in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ unet.to(accelerator.device, dtype=weight_dtype)
+
+ if args.pretrained_vae_model_name_or_path is None:
+ vae.to(accelerator.device, dtype=torch.float32)
+ else:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # now we will add new LoRA weights to the attention layers
+ # Set correct lora layers
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
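+    # LoRA adapters are attached to the attention projections (to_q, to_k, to_v, to_out.0) of the UNet.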
+
+ unet.add_adapter(unet_lora_config)
+
+    # The text encoders come from 🤗 Transformers, so we also attach adapters to them.
+ if args.train_text_encoder:
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder_one.add_adapter(text_lora_config)
+ text_encoder_two.add_adapter(text_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+            # The models list contains either just the unet LoRA layers,
+            # or the unet LoRA layers plus those of both text encoders.
+ unet_lora_layers_to_save = None
+ text_encoder_one_lora_layers_to_save = None
+ text_encoder_two_lora_layers_to_save = None
+
+ for model in models:
+ if isinstance(unwrap_model(model), type(unwrap_model(unet))):
+ unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+ text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
+ text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(model)
+ )
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+                # make sure to pop the weight so that the corresponding model is not saved again
+ if weights:
+ weights.pop()
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ output_dir,
+ unet_lora_layers=unet_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ unet_ = None
+ text_encoder_one_ = None
+ text_encoder_two_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_two))):
+ text_encoder_two_ = model
+ else:
+                raise ValueError(f"unexpected model when loading: {model.__class__}")
+
+ lora_state_dict, _ = LoraLoaderMixin.lora_state_dict(input_dir)
+ unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
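+        # Load the converted LoRA weights into the UNet's default PEFT adapter and surface any unexpected keys.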
+ incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ if args.train_text_encoder:
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+ _set_state_dict_into_text_encoder(
+ lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [unet_]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one_, text_encoder_two_])
+ cast_training_params(models, dtype=torch.float32)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+ text_encoder_two.gradient_checkpointing_enable()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
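+        # i.e. scale the base learning rate by the effective total batch size (per-device batch * accumulation * processes).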
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [unet]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one, text_encoder_two])
+ cast_training_params(models, dtype=torch.float32)
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = list(filter(lambda p: p.requires_grad, unet.parameters()))
+ if args.train_text_encoder:
+ params_to_optimize = (
+ params_to_optimize
+ + list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+ + list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+ )
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+                f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+                f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize input captions and transform the images.
+ def tokenize_captions(examples, is_train=True):
+ captions = []
+ for caption in examples[caption_column]:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+ else:
+ raise ValueError(
+ f"Caption column `{caption_column}` should contain either strings or lists of strings."
+ )
+ tokens_one = tokenize_prompt(tokenizer_one, captions)
+ tokens_two = tokenize_prompt(tokenizer_two, captions)
+ return tokens_one, tokens_two
+
+ # Preprocessing the datasets.
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
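+    # ToTensor scales pixels to [0, 1]; Normalize([0.5], [0.5]) then maps them to [-1, 1], the range the VAE expects.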
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ # image aug
+ original_sizes = []
+ all_images = []
+ crop_top_lefts = []
+ for image in images:
+ original_sizes.append((image.height, image.width))
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ crop_top_left = (y1, x1)
+ crop_top_lefts.append(crop_top_left)
+ image = train_transforms(image)
+ all_images.append(image)
+
+ examples["original_sizes"] = original_sizes
+ examples["crop_top_lefts"] = crop_top_lefts
+ examples["pixel_values"] = all_images
+ tokens_one, tokens_two = tokenize_captions(examples)
+ examples["input_ids_one"] = tokens_one
+ examples["input_ids_two"] = tokens_two
+ if args.debug_loss:
+ fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
+ if fnames:
+ examples["filenames"] = fnames
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)
+
+ def collate_fn(examples):
+ pixel_values = torch.stack([example["pixel_values"] for example in examples])
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ original_sizes = [example["original_sizes"] for example in examples]
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
+ input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
+ input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
+ result = {
+ "pixel_values": pixel_values,
+ "input_ids_one": input_ids_one,
+ "input_ids_two": input_ids_two,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+
+ filenames = [example["filenames"] for example in examples if "filenames" in example]
+ if filenames:
+ result["filenames"] = filenames
+ return result
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
+ )
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers are initialized automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ unet.train()
+ if args.train_text_encoder:
+ text_encoder_one.train()
+ text_encoder_two.train()
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Convert images to latent space
+ if args.pretrained_vae_model_name_or_path is not None:
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+ else:
+ pixel_values = batch["pixel_values"]
+
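+                # Encode images into latents (the stock SDXL VAE runs in fp32 for stability), scale by the
+                # VAE scaling factor, and cast back to the training weight dtype when needed.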
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+ )
+
+ bsz = model_input.shape[0]
+ # Sample a random timestep for each image
+ if args.loss_type == "huber" or args.loss_type == "smooth_l1":
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device="cpu")
+ timestep = timesteps.item()
+
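+                    # A single timestep is drawn on CPU and shared by the whole batch so that huber_c can be
+                    # computed as a scalar from the schedule:
+                    #   "exponential": huber_c = args.huber_c ** (timestep / num_train_timesteps), decaying from 1 to args.huber_c
+                    #   "snr":         interpolates between 1 (low noise) and args.huber_c (high noise) at the sampled timestep
+                    #   "constant":    huber_c = args.huber_c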
+ if args.huber_schedule == "exponential":
+ alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+ huber_c = math.exp(-alpha * timestep)
+ elif args.huber_schedule == "snr":
+ alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+ sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+ elif args.huber_schedule == "constant":
+ huber_c = args.huber_c
+ else:
+ raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+ timesteps = timesteps.repeat(bsz).to(model_input.device)
+ elif args.loss_type == "l2":
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ huber_c = 1 # may be anything, as it's not used
+ else:
+ raise NotImplementedError(f"Unknown loss type {args.loss_type}")
+
+ timesteps = timesteps.long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
+ # time ids
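+                # SDXL micro-conditioning: pack (original_size, crop_top_left, target_size) per sample into add_time_ids.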
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ )
+
+ # Predict the noise residual
+ unet_added_conditions = {"time_ids": add_time_ids}
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=None,
+ prompt=None,
+ text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"]],
+ )
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ prompt_embeds,
+ added_cond_kwargs=unet_added_conditions,
+ return_dict=False,
+ )[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.snr_gamma is None:
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+ )
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
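+                    # min(SNR, gamma) is divided by SNR (epsilon-prediction) or SNR + 1 (v-prediction) below,
+                    # giving the per-timestep Min-SNR-gamma loss weight.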
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+ )
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+ if args.debug_loss and "filenames" in batch:
+ for fname in batch["filenames"]:
+ accelerator.log({"loss_for_" + fname: loss}, step=global_step)
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ # create pipeline
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=unwrap_model(text_encoder_one),
+ text_encoder_2=unwrap_model(text_encoder_two),
+ unet=unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ with torch.cuda.amp.autocast():
+ images = [
+ pipeline(**pipeline_args, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+
+ if args.train_text_encoder:
+ text_encoder_one = unwrap_model(text_encoder_one)
+ text_encoder_two = unwrap_model(text_encoder_two)
+
+ text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
+ text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))
+ else:
+ text_encoder_lora_layers = None
+ text_encoder_2_lora_layers = None
+
+ StableDiffusionXLPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_state_dict,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
+ )
+
+ del unet
+ del text_encoder_one
+ del text_encoder_two
+ del text_encoder_lora_layers
+ del text_encoder_2_lora_layers
+ torch.cuda.empty_cache()
+
+ # Final inference
+ # Make sure vae.dtype is consistent with the unet.dtype
+ if args.mixed_precision == "fp16":
+ vae.to(weight_dtype)
+ # Load previous pipeline
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = pipeline.to(accelerator.device)
+
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ train_text_encoder=args.train_text_encoder,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py
new file mode 100644
index 000000000000..1dac573fce4c
--- /dev/null
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py
@@ -0,0 +1,1394 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fine-tuning script for Stable Diffusion XL for text2image."""
+
+import argparse
+import functools
+import gc
+import logging
+import math
+import os
+import random
+import shutil
+from pathlib import Path
+
+import accelerate
+import datasets
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import concatenate_datasets, load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer, PretrainedConfig
+
+import diffusers
+from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.28.0.dev0")
+
+logger = get_logger(__name__)
+
+
+DATASET_NAME_MAPPING = {
+ "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def save_model_card(
+ repo_id: str,
+ images: list = None,
+ validation_prompt: str = None,
+ base_model: str = None,
+ dataset_name: str = None,
+ repo_folder: str = None,
+ vae_path: str = None,
+):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ model_description = f"""
+# Text-to-image finetuning - {repo_id}
+
+This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}\n
+{img_str}
+
+Special VAE used for training: {vae_path}.
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion-xl",
+ "stable-diffusion-xl-diffusers",
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=1,
+        help=(
+            "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
+            " `args.validation_prompt` a total of `args.num_validation_images` times."
+ ),
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--proportion_empty_prompts",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sdxl-model-finetuned",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+        help="Whether to randomly flip images horizontally.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--timestep_bias_strategy",
+ type=str,
+ default="none",
+ choices=["earlier", "later", "range", "none"],
+ help=(
+ "The timestep bias strategy, which may help direct the model toward learning low or high frequency details."
+ " Choices: ['earlier', 'later', 'range', 'none']."
+ " The default is 'none', which means no bias is applied, and training proceeds normally."
+ " The value of 'later' will increase the frequency of the model's final training timesteps."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_multiplier",
+ type=float,
+ default=1.0,
+ help=(
+ "The multiplier for the bias. Defaults to 1.0, which means no bias is applied."
+ " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_begin",
+ type=int,
+ default=0,
+ help=(
+ "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias."
+ " Defaults to zero, which equates to having no specific bias."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_end",
+ type=int,
+ default=1000,
+ help=(
+ "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias."
+ " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_bias_portion",
+ type=float,
+ default=0.25,
+        help=(
+            "The portion of timesteps to bias. Defaults to 0.25, which means 25% of the timesteps will be biased."
+ " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines"
+ " whether the biased portions are in the earlier or later timesteps."
+ ),
+ )
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--prediction_type",
+ type=str,
+ default=None,
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
+ )
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber", "smooth_l1"],
+        help="The type of loss to use. The Huber variants are scheduled over timesteps via --huber_schedule. See Issue #7488 for more info.",
+ )
+ parser.add_argument(
+ "--huber_schedule",
+ type=str,
+ default="snr",
+ choices=["constant", "exponential", "snr"],
+        help="The schedule for the Huber loss parameter huber_c over timesteps.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.1,
+        help="The Huber loss parameter. Only used if one of the Huber loss modes (huber or smooth_l1) is selected with --loss_type.",
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ # Sanity checks
+ if args.dataset_name is None and args.train_data_dir is None:
+ raise ValueError("Need either a dataset name or a training folder.")
+
+ if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
+ raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
+
+ return args
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
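+# SDXL conditions on two text encoders: the per-token embeddings are the penultimate hidden
+# states of both encoders concatenated along the feature dimension, while only the pooled
+# output of the second (projection) text encoder is kept for the added conditioning.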
+def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):
+ prompt_embeds_list = []
+ prompt_batch = batch[caption_column]
+
+ captions = []
+ for caption in prompt_batch:
+ if random.random() < proportion_empty_prompts:
+ captions.append("")
+ elif isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ return_dict=False,
+ )
+
+            # We are only interested in the pooled output of the final text encoder.
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds[-1][-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}
+
+
+def compute_vae_encodings(batch, vae):
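+    # Encode the images to latents once up front: the latents are scaled by the VAE's
+    # `scaling_factor` (as the UNet expects) and returned on CPU so the `datasets` cache
+    # can store them alongside the text-encoder outputs.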
+ images = batch.pop("pixel_values")
+ pixel_values = torch.stack(list(images))
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
+
+ with torch.no_grad():
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ return {"model_input": model_input.cpu()}
+
+
+def generate_timestep_weights(args, num_timesteps):
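+    # Start from a uniform distribution over timesteps; the selected slice is then scaled by
+    # `timestep_bias_multiplier` and the weights are renormalized so they can be drawn from with
+    # `torch.multinomial` in the training loop.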
+ weights = torch.ones(num_timesteps)
+
+ # Determine the indices to bias
+ num_to_bias = int(args.timestep_bias_portion * num_timesteps)
+
+ if args.timestep_bias_strategy == "later":
+ bias_indices = slice(-num_to_bias, None)
+ elif args.timestep_bias_strategy == "earlier":
+ bias_indices = slice(0, num_to_bias)
+ elif args.timestep_bias_strategy == "range":
+        # Out of the possible 1000 timesteps, we might want to focus on e.g. 200-500.
+ range_begin = args.timestep_bias_begin
+ range_end = args.timestep_bias_end
+ if range_begin < 0:
+            raise ValueError(
+                "When using the range strategy for timestep bias, you must provide a beginning timestep greater than or equal to zero."
+ )
+ if range_end > num_timesteps:
+ raise ValueError(
+ "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps."
+ )
+ bias_indices = slice(range_begin, range_end)
+ else: # 'none' or any other string
+ return weights
+ if args.timestep_bias_multiplier <= 0:
+        raise ValueError(
+ "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps."
+ " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead."
+ " A timestep bias multiplier less than or equal to 0 is not allowed."
+ )
+
+ # Apply the bias
+ weights[bias_indices] *= args.timestep_bias_multiplier
+
+ # Normalize
+ weights /= weights.sum()
+
+ return weights
+
+
+# NOTE: when using the scheduled version, huber_c must already have been computed from the sampled timestep.
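+# The "huber" branch below is a scaled pseudo-Huber loss, 2 * c * (sqrt(d^2 + c^2) - c) with
+# d = model_pred - target: it matches the squared error d^2 for |d| << c and grows linearly
+# (~2 * c * |d|) for |d| >> c, so smaller huber_c gives more robustness to outliers.
+# The "smooth_l1" branch drops the leading c, so its linear regime has slope ~2 regardless of c.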
+def conditional_loss(
+ model_pred: torch.Tensor,
+ target: torch.Tensor,
+ reduction: str = "mean",
+ loss_type: str = "l2",
+ huber_c: float = 0.1,
+):
+ if loss_type == "l2":
+ loss = F.mse_loss(model_pred, target, reduction=reduction)
+ elif loss_type == "huber":
+ loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ elif loss_type == "smooth_l1":
+ loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
+ if reduction == "mean":
+ loss = torch.mean(loss)
+ elif reduction == "sum":
+ loss = torch.sum(loss)
+ else:
+ raise NotImplementedError(f"Unsupported Loss Type {loss_type}")
+ return loss
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+ import wandb
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ datasets.utils.logging.set_verbosity_warning()
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ datasets.utils.logging.set_verbosity_error()
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ # Check for terminal SNR in combination with SNR Gamma
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+
+ # Freeze vae and text encoders.
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ # Set unet as trainable.
+ unet.train()
+
+ # For mixed precision training we cast all non-trainable weights to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ # Create EMA for the unet.
+ if args.use_ema:
+ ema_unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
+ )
+ ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ if args.use_ema:
+ ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+ for i, model in enumerate(models):
+ model.save_pretrained(os.path.join(output_dir, "unet"))
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+ if args.use_ema:
+ load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+ ema_unet.load_state_dict(load_model.state_dict())
+ ema_unet.to(accelerator.device)
+ del load_model
+
+ for _ in range(len(models)):
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ # load diffusers style into model
+ load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+ model.register_to_config(**load_model.config)
+
+ model.load_state_dict(load_model.state_dict())
+ del load_model
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ params_to_optimize = unet.parameters()
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ if args.dataset_name is not None:
+ # Downloading and loading a dataset from the hub.
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ else:
+ data_files = {}
+ if args.train_data_dir is not None:
+ data_files["train"] = os.path.join(args.train_data_dir, "**")
+ dataset = load_dataset(
+ "imagefolder",
+ data_files=data_files,
+ cache_dir=args.cache_dir,
+ )
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
+ if args.image_column is None:
+ image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+            f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
+ )
+ if args.caption_column is None:
+ caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ raise ValueError(
+            f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
+ )
+
+ # Preprocessing the datasets.
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
+
+ def preprocess_train(examples):
+ images = [image.convert("RGB") for image in examples[image_column]]
+ # image aug
+ original_sizes = []
+ all_images = []
+ crop_top_lefts = []
+ for image in images:
+ original_sizes.append((image.height, image.width))
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ crop_top_left = (y1, x1)
+ crop_top_lefts.append(crop_top_left)
+ image = train_transforms(image)
+ all_images.append(image)
+
+ examples["original_sizes"] = original_sizes
+ examples["crop_top_lefts"] = crop_top_lefts
+ examples["pixel_values"] = all_images
+ return examples
+
+ with accelerator.main_process_first():
+ if args.max_train_samples is not None:
+ dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+ # Set the training transforms
+ train_dataset = dataset["train"].with_transform(preprocess_train)
+
+ # Let's first compute all the embeddings so that we can free up the text encoders
+ # from memory. We will pre-compute the VAE encodings too.
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+ compute_embeddings_fn = functools.partial(
+ encode_prompt,
+ text_encoders=text_encoders,
+ tokenizers=tokenizers,
+ proportion_empty_prompts=args.proportion_empty_prompts,
+ caption_column=args.caption_column,
+ )
+ compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
+ with accelerator.main_process_first():
+ from datasets.fingerprint import Hasher
+
+ # fingerprint used by the cache for the other processes to load the result
+ # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
+ new_fingerprint = Hasher.hash(args)
+ new_fingerprint_for_vae = Hasher.hash(vae_path)
+ train_dataset_with_embeddings = train_dataset.map(
+ compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
+ )
+ train_dataset_with_vae = train_dataset.map(
+ compute_vae_encodings_fn,
+ batched=True,
+ batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
+ new_fingerprint=new_fingerprint_for_vae,
+ )
+ precomputed_dataset = concatenate_datasets(
+ [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1
+ )
+ precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)
+
+ del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
+ del text_encoders, tokenizers, vae
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def collate_fn(examples):
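+        # The precomputed latents and embeddings come back from the `datasets` cache as nested
+        # lists, so they are converted back to tensors and stacked into a batch here.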
+ model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
+ original_sizes = [example["original_sizes"] for example in examples]
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
+ prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
+ pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
+
+ return {
+ "model_input": model_input,
+ "prompt_embeds": prompt_embeds,
+ "pooled_prompt_embeds": pooled_prompt_embeds,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ precomputed_dataset,
+ shuffle=True,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ if args.use_ema:
+ ema_unet.to(accelerator.device)
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))
+
+ # Function for unwrapping if torch.compile() was used in accelerate.
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(precomputed_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ train_loss = 0.0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ # Sample noise that we'll add to the latents
+ model_input = batch["model_input"].to(accelerator.device)
+ noise = torch.randn_like(model_input)
+ if args.noise_offset:
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+ noise += args.noise_offset * torch.randn(
+ (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
+ )
+
+ bsz = model_input.shape[0]
+ if args.timestep_bias_strategy == "none":
+ # Sample a random timestep for each image
+ if args.loss_type == "huber" or args.loss_type == "smooth_l1":
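+                        # For the scheduled Huber losses a single timestep is drawn on the CPU and
+                        # shared across the whole batch (it is repeated below), so that huber_c
+                        # stays a scalar that matches the sampled timestep.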
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (1,), device="cpu")
+ timestep = timesteps.item()
+
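+                        # huber_c is scheduled as a function of the sampled timestep: the
+                        # "exponential" schedule decays it smoothly from 1 at t=0 to args.huber_c at
+                        # the final timestep, and the "snr" schedule shrinks it toward args.huber_c
+                        # as the noise level sigma grows, so noisier timesteps use a more
+                        # outlier-robust (more L1-like) loss.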
+ if args.huber_schedule == "exponential":
+ alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
+ huber_c = math.exp(-alpha * timestep)
+ elif args.huber_schedule == "snr":
+ alphas_cumprod = noise_scheduler.alphas_cumprod[timestep]
+ sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
+ elif args.huber_schedule == "constant":
+ huber_c = args.huber_c
+ else:
+ raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!")
+
+ timesteps = timesteps.repeat(bsz).to(model_input.device)
+ elif args.loss_type == "l2":
+ timesteps = torch.randint(
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
+ )
+ huber_c = 1 # may be anything, as it's not used
+ else:
+ raise NotImplementedError(f"Unknown loss type {args.loss_type}")
+
+ timesteps = timesteps.long()
+
+ else:
+                    if args.loss_type in ("huber", "smooth_l1") and args.huber_schedule != "constant":
+ raise NotImplementedError(
+ "Randomly weighted timesteps not implemented yet for scheduled huber loss!"
+ )
+ else:
+ huber_c = args.huber_c
+ # Sample a random timestep for each image, potentially biased by the timestep weights.
+ # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.
+ weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(
+ model_input.device
+ )
+ timesteps = torch.multinomial(weights, bsz, replacement=True).long()
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+
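+                # SDXL micro-conditioning: each sample carries six integers
+                # (original_size, crop_top_left, target_size) that the UNet embeds and adds to
+                # its timestep embedding, following the SDXL report.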
+ # time ids
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ )
+
+ # Predict the noise residual
+ unet_added_conditions = {"time_ids": add_time_ids}
+ prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ prompt_embeds,
+ added_cond_kwargs=unet_added_conditions,
+ return_dict=False,
+ )[0]
+
+ # Get the target for loss depending on the prediction type
+ if args.prediction_type is not None:
+ # set prediction_type of scheduler if defined
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
+
+ if noise_scheduler.config.prediction_type == "epsilon":
+ target = noise
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
+ elif noise_scheduler.config.prediction_type == "sample":
+ # We set the target to latents here, but the model_pred will return the noise sample prediction.
+ target = model_input
+ # We will have to subtract the noise residual from the prediction to get the target sample.
+ model_pred = model_pred - noise
+ else:
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+ if args.snr_gamma is None:
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+ )
+ else:
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+ # This is discussed in Section 4.2 of the same paper.
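+                    # The resulting weight is min(SNR, gamma) / SNR for epsilon-prediction
+                    # (min(SNR, gamma) / (SNR + 1) for v-prediction), which caps the contribution of
+                    # low-noise timesteps whose SNR is very large.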
+ snr = compute_snr(noise_scheduler, timesteps)
+ mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+ dim=1
+ )[0]
+ if noise_scheduler.config.prediction_type == "epsilon":
+ mse_loss_weights = mse_loss_weights / snr
+ elif noise_scheduler.config.prediction_type == "v_prediction":
+ mse_loss_weights = mse_loss_weights / (snr + 1)
+
+ loss = conditional_loss(
+ model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+ )
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+ loss = loss.mean()
+
+ # Gather the losses across all processes for logging (if we use distributed training).
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
+
+ # Backpropagate
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = unet.parameters()
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ if args.use_ema:
+ ema_unet.step(unet.parameters())
+ progress_bar.update(1)
+ global_step += 1
+ accelerator.log({"train_loss": train_loss}, step=global_step)
+ train_loss = 0.0
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ if args.use_ema:
+ # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+ ema_unet.store(unet.parameters())
+ ema_unet.copy_to(unet.parameters())
+
+ # create pipeline
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ unet=accelerator.unwrap_model(unet),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ if args.prediction_type is not None:
+ scheduler_args = {"prediction_type": args.prediction_type}
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ pipeline_args = {"prompt": args.validation_prompt}
+
+ with torch.cuda.amp.autocast():
+ images = [
+ pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "validation": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ torch.cuda.empty_cache()
+
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = unwrap_model(unet)
+ if args.use_ema:
+ ema_unet.copy_to(unet.parameters())
+
+ # Serialize pipeline.
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unet,
+ vae=vae,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ if args.prediction_type is not None:
+ scheduler_args = {"prediction_type": args.prediction_type}
+ pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+ pipeline.save_pretrained(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline = pipeline.to(accelerator.device)
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ with torch.cuda.amp.autocast():
+ images = [
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+ for _ in range(args.num_validation_images)
+ ]
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ "test": [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+ for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id=repo_id,
+ images=images,
+ validation_prompt=args.validation_prompt,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ vae_path=args.pretrained_vae_model_name_or_path,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
index 973915c96da4..50735ef044a6 100644
--- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py
+++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
@@ -799,6 +799,10 @@ def main(args):
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 4138d1b46329..84f4c6514cfd 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -20,6 +20,7 @@
import os
import random
import shutil
+from contextlib import nullcontext
from pathlib import Path
import accelerate
@@ -164,7 +165,12 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
images = []
for i in range(len(args.validation_prompts)):
- with torch.autocast("cuda"):
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
images.append(image)
@@ -523,6 +529,10 @@ def main():
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 71b99f1588c3..7164ac909cb2 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -21,6 +21,7 @@
import os
import random
import shutil
+from contextlib import nullcontext
from pathlib import Path
import datasets
@@ -408,6 +409,11 @@ def main():
log_with=args.report_to,
project_config=accelerator_project_config,
)
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
@@ -878,7 +884,12 @@ def collate_fn(examples):
if args.seed is not None:
generator = generator.manual_seed(args.seed)
images = []
- with torch.cuda.amp.autocast():
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
for _ in range(args.num_validation_images):
images.append(
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
@@ -948,7 +959,12 @@ def collate_fn(examples):
if args.seed is not None:
generator = generator.manual_seed(args.seed)
images = []
- with torch.cuda.amp.autocast():
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
for _ in range(args.num_validation_images):
images.append(
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index c9860b744c03..0a6a70de2dc7 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -21,6 +21,7 @@
import os
import random
import shutil
+from contextlib import nullcontext
from pathlib import Path
import datasets
@@ -979,13 +980,6 @@ def collate_fn(examples):
if accelerator.is_main_process:
accelerator.init_trackers("text2image-fine-tune", config=vars(args))
- # Some configurations require autocast to be disabled.
- enable_autocast = True
- if torch.backends.mps.is_available() or (
- accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
- ):
- enable_autocast = False
-
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -1211,11 +1205,12 @@ def compute_time_ids(original_size, crops_coords_top_left):
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
- with torch.autocast(
- accelerator.device.type,
- enabled=enable_autocast,
- ):
+ with autocast_ctx:
images = [
pipeline(**pipeline_args, generator=generator).images[0]
for _ in range(args.num_validation_images)
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index c141f5bdd706..88adbb995531 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -23,6 +23,7 @@
import os
import random
import shutil
+from contextlib import nullcontext
from pathlib import Path
import accelerate
@@ -603,6 +604,10 @@ def main(args):
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
@@ -986,12 +991,10 @@ def unwrap_model(model):
model = model._orig_mod if is_compiled_module(model) else model
return model
- # Some configurations require autocast to be disabled.
- enable_autocast = True
- if torch.backends.mps.is_available() or (
- accelerator.mixed_precision == "fp16" or accelerator.mixed_precision == "bf16"
- ):
- enable_autocast = False
+ if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -1226,10 +1229,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}
- with torch.autocast(
- accelerator.device.type,
- enabled=enable_autocast,
- ):
+ with autocast_ctx:
images = [
pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
for _ in range(args.num_validation_images)
@@ -1252,6 +1252,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
del pipeline
torch.cuda.empty_cache()
+ if args.use_ema:
+ # Switch back to the original UNet parameters.
+ ema_unet.restore(unet.parameters())
+
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unwrap_model(unet)
@@ -1284,7 +1288,8 @@ def compute_time_ids(original_size, crops_coords_top_left):
if args.validation_prompt and args.num_validation_images > 0:
pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
- with torch.autocast(accelerator.device.type, enabled=enable_autocast):
+
+ with autocast_ctx:
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
for _ in range(args.num_validation_images)
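The `ema_unet.restore(...)` call added above closes the usual store / copy_to / restore cycle around validation. A minimal sketch of that cycle with `EMAModel`, using a stand-in module instead of the real UNet:

```python
import torch
from diffusers.training_utils import EMAModel

model = torch.nn.Linear(4, 4)          # stand-in for the trained UNet
ema = EMAModel(model.parameters())

# ... after each optimizer step: ema.step(model.parameters()) ...

ema.store(model.parameters())          # stash the live training weights
ema.copy_to(model.parameters())        # validate/infer with the EMA weights
# ... run validation ...
ema.restore(model.parameters())        # switch back to the original training weights
```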
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 0f4bb7604f3c..4922789862b5 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -20,6 +20,7 @@
import random
import shutil
import warnings
+from contextlib import nullcontext
from pathlib import Path
import numpy as np
@@ -143,7 +144,12 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
- with torch.autocast("cuda"):
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
images.append(image)
@@ -600,6 +606,10 @@ def main():
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py
index 460acf9f8009..c24a4c4f4855 100644
--- a/examples/textual_inversion/textual_inversion_sdxl.py
+++ b/examples/textual_inversion/textual_inversion_sdxl.py
@@ -605,6 +605,10 @@ def main():
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
index e7d5898e1118..76eaf6423960 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
@@ -460,6 +460,10 @@ def main():
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
index 7aaebed3b085..49cc5d26072d 100644
--- a/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
+++ b/examples/wuerstchen/text_to_image/train_text_to_image_prior.py
@@ -458,6 +458,10 @@ def main():
project_config=accelerator_project_config,
)
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
diff --git a/pyproject.toml b/pyproject.toml
index 0612f2f9e059..299865a1225d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,15 +1,17 @@
[tool.ruff]
+line-length = 119
+
+[tool.ruff.lint]
# Never enforce `E501` (line length violations).
ignore = ["C901", "E501", "E741", "F402", "F823"]
select = ["C", "E", "F", "I", "W"]
-line-length = 119
# Ignore import violations in all `__init__.py` files.
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "F401", "F403", "F811"]
"src/diffusers/utils/dummy_*.py" = ["F401"]
-[tool.ruff.isort]
+[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["diffusers"]
diff --git a/setup.py b/setup.py
index bbf8ecfde174..943238df765d 100644
--- a/setup.py
+++ b/setup.py
@@ -95,7 +95,7 @@
# 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py
_deps = [
"Pillow", # keep the PIL.Image.Resampling deprecation away
- "accelerate>=0.11.0",
+ "accelerate>=0.29.3",
"compel==0.1.8",
"datasets",
"filelock",
@@ -134,6 +134,7 @@
"torchvision",
"transformers>=4.25.1",
"urllib3<=2.0.0",
+ "black",
]
# this is a lookup table with items like:
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 770045923d5d..5d6761663938 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -80,6 +80,7 @@
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
+ "ControlNetXSAdapter",
"I2VGenXLUNet",
"Kandinsky3UNet",
"ModelMixin",
@@ -94,6 +95,7 @@
"UNet2DConditionModel",
"UNet2DModel",
"UNet3DConditionModel",
+ "UNetControlNetXSModel",
"UNetMotionModel",
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
@@ -270,6 +272,7 @@
"StableDiffusionControlNetImg2ImgPipeline",
"StableDiffusionControlNetInpaintPipeline",
"StableDiffusionControlNetPipeline",
+ "StableDiffusionControlNetXSPipeline",
"StableDiffusionDepth2ImgPipeline",
"StableDiffusionDiffEditPipeline",
"StableDiffusionGLIGENPipeline",
@@ -293,6 +296,7 @@
"StableDiffusionXLControlNetImg2ImgPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
"StableDiffusionXLControlNetPipeline",
+ "StableDiffusionXLControlNetXSPipeline",
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
@@ -474,6 +478,7 @@
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
+ ControlNetXSAdapter,
I2VGenXLUNet,
Kandinsky3UNet,
ModelMixin,
@@ -487,6 +492,7 @@
UNet2DConditionModel,
UNet2DModel,
UNet3DConditionModel,
+ UNetControlNetXSModel,
UNetMotionModel,
UNetSpatioTemporalConditionModel,
UVit2DModel,
@@ -642,6 +648,7 @@
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
+ StableDiffusionControlNetXSPipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionDiffEditPipeline,
StableDiffusionGLIGENPipeline,
@@ -665,6 +672,7 @@
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
+ StableDiffusionXLControlNetXSPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
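The new ControlNet-XS classes exported here pair a `ControlNetXSAdapter` with a frozen base UNet inside the dedicated pipelines. A rough usage sketch; the ControlNet-XS checkpoint ID and the control image are placeholders, not taken from this diff:

```python
import torch
from diffusers import ControlNetXSAdapter, StableDiffusionXLControlNetXSPipeline

# Placeholder repo ID: substitute a real ControlNet-XS checkpoint trained for SDXL.
controlnet = ControlNetXSAdapter.from_pretrained("your-org/controlnet-xs-sdxl-canny", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# `canny_image` is an assumed pre-computed control image (e.g. a Canny edge map).
# image = pipe("a futuristic house, photorealistic", image=canny_image).images[0]
```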
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index e92a486bffc1..ca233a4158bc 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -3,7 +3,7 @@
# 2. run `make deps_table_update`
deps = {
"Pillow": "Pillow",
- "accelerate": "accelerate>=0.11.0",
+ "accelerate": "accelerate>=0.29.3",
"compel": "compel==0.1.8",
"datasets": "datasets",
"filelock": "filelock",
@@ -42,4 +42,5 @@
"torchvision": "torchvision",
"transformers": "transformers>=4.25.1",
"urllib3": "urllib3<=2.0.0",
+ "black": "black",
}
diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
index daeb8fd6fa6d..eac3f9b7d578 100644
--- a/src/diffusers/image_processor.py
+++ b/src/diffusers/image_processor.py
@@ -173,8 +173,9 @@ def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
@staticmethod
def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
"""
- Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image;
- for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.
+ Finds a rectangular region that contains all masked areas in an image, and expands the region to match the
+ aspect ratio of the original image; for example, if a user drew a mask in a 128x32 region, and the dimensions
+ for processing are 512x512, the region will be expanded to 128x128.
Args:
mask_image (PIL.Image.Image): Mask image.
@@ -183,7 +184,8 @@ def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0)
pad (int, optional): Padding to be added to the crop region. Defaults to 0.
Returns:
- tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio.
+ tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked areas in an image and
+ matches the original aspect ratio.
"""
mask_image = mask_image.convert("L")
@@ -265,7 +267,8 @@ def _resize_and_fill(
height: int,
) -> PIL.Image.Image:
"""
- Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
+ Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+ the image within the dimensions, filling the empty space with data from the image.
Args:
image: The image to resize.
@@ -309,7 +312,8 @@ def _resize_and_crop(
height: int,
) -> PIL.Image.Image:
"""
- Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
+ Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+ the image within the dimensions, cropping the excess.
Args:
image: The image to resize.
@@ -346,12 +350,12 @@ def resize(
The width to resize to.
resize_mode (`str`, *optional*, defaults to `default`):
The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
- within the specified width and height, and it may not maintaining the original aspect ratio.
- If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
- within the dimensions, filling empty with data from image.
- If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
- within the dimensions, cropping the excess.
- Note that resize_mode `fill` and `crop` are only supported for PIL image input.
+ within the specified width and height, and it may not maintain the original aspect ratio. If `fill`,
+ will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
+ then center the image within the dimensions, filling the empty space with data from the image. If `crop`, will resize
+ the image to fit within the specified width and height, maintaining the aspect ratio, and then center
+ the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
+ supported for PIL image input.
Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
@@ -456,19 +460,21 @@ def preprocess(
Args:
image (`pipeline_image_input`):
- The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats.
+ The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; also accepts lists of
+ supported formats.
height (`int`, *optional*, defaults to `None`):
- The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height.
+ The height of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default
+ height.
width (`int`, *optional*`, defaults to `None`):
- The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
+ The width of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default width.
resize_mode (`str`, *optional*, defaults to `default`):
- The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit
- within the specified width and height, and it may not maintaining the original aspect ratio.
- If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
- within the dimensions, filling empty with data from image.
- If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
- within the dimensions, cropping the excess.
- Note that resize_mode `fill` and `crop` are only supported for PIL image input.
+ The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
+ the specified width and height, and it may not maintain the original aspect ratio. If `fill`, will
+ resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
+ center the image within the dimensions, filling the empty space with data from the image. If `crop`, will resize the
+ image to fit within the specified width and height, maintaining the aspect ratio, and then center the
+ image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
+ supported for PIL image input.
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
The crop coordinates for each image in the batch. If `None`, will not crop the image.
"""
@@ -930,8 +936,8 @@ def __init__(
@staticmethod
def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
"""
- Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention.
- If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
+ Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
+ aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
Args:
mask (`torch.FloatTensor`):
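The reflowed `resize`/`preprocess` docstrings above describe three resize modes. A small sketch of the difference, assuming the documented `VaeImageProcessor.preprocess` signature:

```python
import PIL.Image

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()
image = PIL.Image.new("RGB", (128, 32), "white")      # toy wide-and-short input

default = processor.preprocess(image, height=512, width=512)                      # may distort the aspect ratio
filled = processor.preprocess(image, height=512, width=512, resize_mode="fill")   # keep ratio, pad using image data
cropped = processor.preprocess(image, height=512, width=512, resize_mode="crop")  # keep ratio, crop the excess
```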
diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index a4593ec69404..fdddc382212f 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -21,6 +21,7 @@
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..utils import (
+ USE_PEFT_BACKEND,
_get_model_file,
is_accelerate_available,
is_torch_version,
@@ -67,17 +68,18 @@ def load_ip_adapter(
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `List[str]`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- If a list is passed, it should have the same length as `weight_name`.
+ The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+ list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`weight_name`.
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
- Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`,
- you only need to pass the name of the folder that contains image encoder weights, e.g. `image_encoder_folder="image_encoder"`.
- If the image encoder is located in a folder other than `subfolder`, you should pass the path to the folder that contains image encoder weights,
- for example, `image_encoder_folder="different_subfolder/image_encoder"`.
+ Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
+ `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
+ `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
+ `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
+ `image_encoder_folder="different_subfolder/image_encoder"`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
@@ -227,6 +229,18 @@ def load_ip_adapter(
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+ extra_loras = unet._load_ip_adapter_loras(state_dicts)
+ if extra_loras != {}:
+ if not USE_PEFT_BACKEND:
+ logger.warning("PEFT backend is required to load these weights.")
+ else:
+ # apply the IP Adapter Face ID LoRA weights
+ peft_config = getattr(unet, "peft_config", {})
+ for k, lora in extra_loras.items():
+ if f"faceid_{k}" not in peft_config:
+ self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
+ self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
+
def set_ip_adapter_scale(self, scale):
"""
Sets the conditioning scale between text and image.
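With the change above, IP-Adapter Face ID checkpoints that bundle LoRA weights get those weights applied automatically when the PEFT backend is installed. A hedged loading sketch; the repo and weight file names below are assumptions, not taken from this diff:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.load_ip_adapter(
    "h94/IP-Adapter-FaceID",                   # assumed Hub repo for the Face ID weights
    subfolder=None,
    weight_name="ip-adapter-faceid_sd15.bin",  # assumed weight file name
    image_encoder_folder=None,                 # Face ID uses face embeddings, not a CLIP image encoder
)
pipe.set_ip_adapter_scale(0.6)
# If a LoRA is bundled, it is now loaded as adapter "faceid_0" via load_lora_weights/set_adapters.
```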
diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index df7dfbcd8871..5d89658830f1 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -1267,6 +1267,10 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
for adapter_name in adapter_names:
unet_module.lora_A[adapter_name].to(device)
unet_module.lora_B[adapter_name].to(device)
+ # this is a param, not a module, so device placement is not in-place -> re-assign
+ unet_module.lora_magnitude_vector[adapter_name] = unet_module.lora_magnitude_vector[
+ adapter_name
+ ].to(device)
# Handle the text encoder
modules_to_process = []
@@ -1283,6 +1287,10 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
for adapter_name in adapter_names:
text_encoder_module.lora_A[adapter_name].to(device)
text_encoder_module.lora_B[adapter_name].to(device)
+ # this is a param, not a module, so device placement is not in-place -> re-assign
+ text_encoder_module.lora_magnitude_vector[
+ adapter_name
+ ] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device)
class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
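The re-assignment above matters for DoRA-style adapters, whose `lora_magnitude_vector` entries are parameters rather than sub-modules and therefore are not moved in place by `.to()`. A short sketch of the affected entry point; the LoRA repo name is a placeholder:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.load_lora_weights("your-org/sdxl-dora-lora", adapter_name="sketch")  # placeholder DoRA LoRA repo
pipe.set_lora_device(adapter_names=["sketch"], device="cuda")             # now also relocates lora_magnitude_vector
```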
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 01dbd3494a4c..5892c2865374 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -20,7 +20,8 @@
class PeftAdapterMixin:
"""
A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
- more details about adapters and injecting them in a transformer-based model, check out the PEFT [documentation](https://huggingface.co/docs/peft/index).
+ more details about adapters and injecting them in a transformer-based model, check out the PEFT
+ [documentation](https://huggingface.co/docs/peft/index).
Install the latest version of PEFT, and use this mixin to:
@@ -143,8 +144,8 @@ def disable_adapters(self) -> None:
def enable_adapters(self) -> None:
"""
- Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the
- list of adapters to enable.
+ Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of
+ adapters to enable.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
[documentation](https://huggingface.co/docs/peft).
diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index 0d384b1647d5..752ef18c7a0b 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -198,19 +198,24 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
model_type (`str`, *optional*):
The type of model to load. If not provided, the model type will be inferred from the checkpoint file.
image_size (`int`, *optional*):
- The size of the image output. It's used to configure the `sample_size` parameter of the UNet and VAE model.
+ The size of the image output. It's used to configure the `sample_size` parameter of the UNet and VAE
+ model.
load_safety_checker (`bool`, *optional*, defaults to `False`):
- Whether to load the safety checker model or not. By default, the safety checker is not loaded unless a `safety_checker` component is passed to the `kwargs`.
+ Whether to load the safety checker model or not. By default, the safety checker is not loaded unless a
+ `safety_checker` component is passed to the `kwargs`.
num_in_channels (`int`, *optional*):
- Specify the number of input channels for the UNet model. Read more about how to configure UNet model with this parameter
+ Specify the number of input channels for the UNet model. Read more about how to configure the UNet model
+ with this parameter
[here](https://huggingface.co/docs/diffusers/training/adapt_a_model#configure-unet2dconditionmodel-parameters).
scaling_factor (`float`, *optional*):
- The scaling factor to use for the VAE model. If not provided, it is inferred from the config file first.
- If the scaling factor is not found in the config file, the default value 0.18215 is used.
+ The scaling factor to use for the VAE model. If not provided, it is inferred from the config file
+ first. If the scaling factor is not found in the config file, the default value 0.18215 is used.
scheduler_type (`str`, *optional*):
- The type of scheduler to load. If not provided, the scheduler type will be inferred from the checkpoint file.
+ The type of scheduler to load. If not provided, the scheduler type will be inferred from the checkpoint
+ file.
prediction_type (`str`, *optional*):
- The type of prediction to load. If not provided, the prediction type will be inferred from the checkpoint file.
+ The type of prediction to load. If not provided, the prediction type will be inferred from the
+ checkpoint file.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
diff --git a/src/diffusers/loaders/textual_inversion.py b/src/diffusers/loaders/textual_inversion.py
index eb727990af18..c1c224975cb8 100644
--- a/src/diffusers/loaders/textual_inversion.py
+++ b/src/diffusers/loaders/textual_inversion.py
@@ -487,20 +487,35 @@ def unload_textual_inversion(
# Example 3: unload from SDXL
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
- embedding_path = hf_hub_download(repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model")
+ embedding_path = hf_hub_download(
+ repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model"
+ )
# load embeddings to the text encoders
state_dict = load_file(embedding_path)
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
- pipeline.load_textual_inversion(state_dict["clip_l"], token=["", ""], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
+ pipeline.load_textual_inversion(
+ state_dict["clip_l"],
+ token=["", ""],
+ text_encoder=pipeline.text_encoder,
+ tokenizer=pipeline.tokenizer,
+ )
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
- pipeline.load_textual_inversion(state_dict["clip_g"], token=["", ""], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
+ pipeline.load_textual_inversion(
+ state_dict["clip_g"],
+ token=["", ""],
+ text_encoder=pipeline.text_encoder_2,
+ tokenizer=pipeline.tokenizer_2,
+ )
# Unload explicitly from both text encoders and tokenizers
- pipeline.unload_textual_inversion(tokens=["", ""], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
- pipeline.unload_textual_inversion(tokens=["", ""], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
-
+ pipeline.unload_textual_inversion(
+ tokens=["", ""], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
+ )
+ pipeline.unload_textual_inversion(
+ tokens=["", ""], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
+ )
```
"""
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 8bbec26189b0..294db44ee61d 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -27,6 +27,8 @@
from ..models.embeddings import (
ImageProjection,
+ IPAdapterFaceIDImageProjection,
+ IPAdapterFaceIDPlusImageProjection,
IPAdapterFullImageProjection,
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
@@ -756,6 +758,90 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
diffusers_name = diffusers_name.replace("proj.3", "norm")
updated_state_dict[diffusers_name] = value
+ elif "perceiver_resampler.proj_in.weight" in state_dict:
+ # IP-Adapter Face ID Plus
+ id_embeddings_dim = state_dict["proj.0.weight"].shape[1]
+ embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0]
+ hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1]
+ output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0]
+ heads = state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64
+
+ with init_context():
+ image_projection = IPAdapterFaceIDPlusImageProjection(
+ embed_dims=embed_dims,
+ output_dims=output_dims,
+ hidden_dims=hidden_dims,
+ heads=heads,
+ id_embeddings_dim=id_embeddings_dim,
+ )
+
+ for key, value in state_dict.items():
+ diffusers_name = key.replace("perceiver_resampler.", "")
+ diffusers_name = diffusers_name.replace("0.to", "attn.to")
+ diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.")
+ diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight")
+ diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight")
+ diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.")
+ diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight")
+ diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight")
+ diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.")
+ diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight")
+ diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight")
+ diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.")
+ diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight")
+ diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight")
+ diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0")
+ diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1")
+ diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0")
+ diffusers_name = diffusers_name.replace("layers.1.1", "layers.1.ln1")
+ diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0")
+ diffusers_name = diffusers_name.replace("layers.2.1", "layers.2.ln1")
+ diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0")
+ diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1")
+
+ if "norm1" in diffusers_name:
+ updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
+ elif "norm2" in diffusers_name:
+ updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
+ elif "to_kv" in diffusers_name:
+ v_chunk = value.chunk(2, dim=0)
+ updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
+ updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
+ elif "to_out" in diffusers_name:
+ updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
+ elif "proj.0.weight" == diffusers_name:
+ updated_state_dict["proj.net.0.proj.weight"] = value
+ elif "proj.0.bias" == diffusers_name:
+ updated_state_dict["proj.net.0.proj.bias"] = value
+ elif "proj.2.weight" == diffusers_name:
+ updated_state_dict["proj.net.2.weight"] = value
+ elif "proj.2.bias" == diffusers_name:
+ updated_state_dict["proj.net.2.bias"] = value
+ else:
+ updated_state_dict[diffusers_name] = value
+
+ elif "norm.weight" in state_dict:
+ # IP-Adapter Face ID
+ id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]
+ id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]
+ multiplier = id_embeddings_dim_out // id_embeddings_dim_in
+ norm_layer = "norm.weight"
+ cross_attention_dim = state_dict[norm_layer].shape[0]
+ num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim
+
+ with init_context():
+ image_projection = IPAdapterFaceIDImageProjection(
+ cross_attention_dim=cross_attention_dim,
+ image_embed_dim=id_embeddings_dim_in,
+ mult=multiplier,
+ num_tokens=num_tokens,
+ )
+
+ for key, value in state_dict.items():
+ diffusers_name = key.replace("proj.0", "ff.net.0.proj")
+ diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
+ updated_state_dict[diffusers_name] = value
+
else:
# IP-Adapter Plus
num_image_text_embeds = state_dict["latents"].shape[1]
@@ -847,6 +933,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
)
attn_procs[name] = attn_processor_class()
+
else:
attn_processor_class = (
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
@@ -859,6 +946,12 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
elif "proj.3.weight" in state_dict["image_proj"]:
# IP-Adapter Full Face
num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token
+ elif "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]:
+ # IP-Adapter Face ID Plus
+ num_image_text_embeds += [4]
+ elif "norm.weight" in state_dict["image_proj"]:
+ # IP-Adapter Face ID
+ num_image_text_embeds += [4]
else:
# IP-Adapter Plus
num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]]
@@ -910,6 +1003,59 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
self.to(dtype=self.dtype, device=self.device)
+ def _load_ip_adapter_loras(self, state_dicts):
+ lora_dicts = {}
+ for key_id, name in enumerate(self.attn_processors.keys()):
+ for i, state_dict in enumerate(state_dicts):
+ if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]:
+ if i not in lora_dicts:
+ lora_dicts[i] = {}
+ lora_dicts[i].update(
+ {
+ f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][
+ f"{key_id}.to_k_lora.down.weight"
+ ]
+ }
+ )
+ lora_dicts[i].update(
+ {
+ f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][
+ f"{key_id}.to_q_lora.down.weight"
+ ]
+ }
+ )
+ lora_dicts[i].update(
+ {
+ f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][
+ f"{key_id}.to_v_lora.down.weight"
+ ]
+ }
+ )
+ lora_dicts[i].update(
+ {
+ f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][
+ f"{key_id}.to_out_lora.down.weight"
+ ]
+ }
+ )
+ lora_dicts[i].update(
+ {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}
+ )
+ lora_dicts[i].update(
+ {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}
+ )
+ lora_dicts[i].update(
+ {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}
+ )
+ lora_dicts[i].update(
+ {
+ f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][
+ f"{key_id}.to_out_lora.up.weight"
+ ]
+ }
+ )
+ return lora_dicts
+
class FromOriginalUNetMixin:
"""
@@ -998,7 +1144,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
if is_accelerate_available():
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
if len(unexpected_keys) > 0:
- logger.warn(
+ logger.warning(
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)
diff --git a/src/diffusers/loaders/unet_loader_utils.py b/src/diffusers/loaders/unet_loader_utils.py
index 918a0fca06c8..3ee4a96fad0a 100644
--- a/src/diffusers/loaders/unet_loader_utils.py
+++ b/src/diffusers/loaders/unet_loader_utils.py
@@ -74,37 +74,24 @@ def _maybe_expand_lora_scales_for_one_adapter(
E.g. turns
```python
- scales = {
- 'down': 2,
- 'mid': 3,
- 'up': {
- 'block_0': 4,
- 'block_1': [5, 6, 7]
- }
- }
- blocks_with_transformer = {
- 'down': [1,2],
- 'up': [0,1]
- }
- transformer_per_block = {
- 'down': 2,
- 'up': 3
- }
+ scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
+ blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
+ transformer_per_block = {"down": 2, "up": 3}
```
into
```python
{
- 'down.block_1.0': 2,
- 'down.block_1.1': 2,
- 'down.block_2.0': 2,
- 'down.block_2.1': 2,
- 'mid': 3,
- 'up.block_0.0': 4,
- 'up.block_0.1': 4,
- 'up.block_0.2': 4,
- 'up.block_1.0': 5,
- 'up.block_1.1': 6,
- 'up.block_1.2': 7,
+ "down.block_1.0": 2,
+ "down.block_1.1": 2,
+ "down.block_2.0": 2,
+ "down.block_2.1": 2,
+ "mid": 3,
+ "up.block_0.0": 4,
+ "up.block_0.1": 4,
+ "up.block_0.2": 4,
+ "up.block_1.0": 5,
+ "up.block_1.1": 6,
+ "up.block_1.2": 7,
}
```
"""
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index da77e4450c86..78b0efff921d 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -32,6 +32,7 @@
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
+ _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
@@ -68,6 +69,7 @@
ConsistencyDecoderVAE,
)
from .controlnet import ControlNetModel
+ from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import (
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 1fd29ce708c8..237e8236caf4 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -13,7 +13,7 @@
# limitations under the License.
import inspect
from importlib import import_module
-from typing import Callable, Optional, Union
+from typing import Callable, List, Optional, Union
import torch
import torch.nn.functional as F
@@ -1298,9 +1298,9 @@ def __call__(
class FusedAttnProcessor2_0:
r"""
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
- It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
- key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
+ fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
+ For cross-attention modules, key and value projection matrices are fused.
@@ -2195,15 +2195,33 @@ def __call__(
hidden_states = attn.batch_to_head_dim(hidden_states)
if ip_adapter_masks is not None:
- if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
- raise ValueError(
- " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
- " Please use `IPAdapterMaskProcessor` to preprocess your mask"
- )
- if len(ip_adapter_masks) != len(self.scale):
+ if not isinstance(ip_adapter_masks, List):
+ # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+ ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+ if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
raise ValueError(
- f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
+ f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+ f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+ f"({len(ip_hidden_states)})"
)
+ else:
+ for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+ if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+ raise ValueError(
+ "Each element of the ip_adapter_masks array should be a tensor with shape "
+ "[1, num_images_for_ip_adapter, height, width]."
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+ )
+ if mask.shape[1] != ip_state.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of ip images ({ip_state.shape[1]}) at index {index}"
+ )
+ if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of scales ({len(scale)}) at index {index}"
+ )
else:
ip_adapter_masks = [None] * len(self.scale)
@@ -2211,26 +2229,44 @@ def __call__(
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
- ip_key = to_k_ip(current_ip_hidden_states)
- ip_value = to_v_ip(current_ip_hidden_states)
-
- ip_key = attn.head_to_batch_dim(ip_key)
- ip_value = attn.head_to_batch_dim(ip_value)
+ if mask is not None:
+ if not isinstance(scale, list):
+ scale = [scale]
+
+ current_num_images = mask.shape[1]
+ for i in range(current_num_images):
+ ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+ ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
+
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+ mask_downsample = IPAdapterMaskProcessor.downsample(
+ mask[:, i, :, :],
+ batch_size,
+ _current_ip_hidden_states.shape[1],
+ _current_ip_hidden_states.shape[2],
+ )
- ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
- current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
- current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
- if mask is not None:
- mask_downsample = IPAdapterMaskProcessor.downsample(
- mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
- )
+ hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+ else:
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
- mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
- current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
- hidden_states = hidden_states + scale * current_ip_hidden_states
+ hidden_states = hidden_states + scale * current_ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
@@ -2369,15 +2405,33 @@ def __call__(
hidden_states = hidden_states.to(query.dtype)
if ip_adapter_masks is not None:
- if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
+ if not isinstance(ip_adapter_masks, List):
+ # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+ ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+ if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
raise ValueError(
- " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
- " Please use `IPAdapterMaskProcessor` to preprocess your mask"
- )
- if len(ip_adapter_masks) != len(self.scale):
- raise ValueError(
- f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
+ f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+ f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+ f"({len(ip_hidden_states)})"
)
+ else:
+ for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+ if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+ raise ValueError(
+ "Each element of the ip_adapter_masks array should be a tensor with shape "
+ "[1, num_images_for_ip_adapter, height, width]."
+ " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+ )
+ if mask.shape[1] != ip_state.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of ip images ({ip_state.shape[1]}) at index {index}"
+ )
+ if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+ raise ValueError(
+ f"Number of masks ({mask.shape[1]}) does not match "
+ f"number of scales ({len(scale)}) at index {index}"
+ )
else:
ip_adapter_masks = [None] * len(self.scale)
@@ -2385,33 +2439,57 @@ def __call__(
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
- ip_key = to_k_ip(current_ip_hidden_states)
- ip_value = to_v_ip(current_ip_hidden_states)
+ if mask is not None:
+ if not isinstance(scale, list):
+ scale = [scale]
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ current_num_images = mask.shape[1]
+ for i in range(current_num_images):
+ ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+ ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- current_ip_hidden_states = F.scaled_dot_product_attention(
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
- batch_size, -1, attn.heads * head_dim
- )
- current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ _current_ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
- if mask is not None:
- mask_downsample = IPAdapterMaskProcessor.downsample(
- mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
- )
+ _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
- mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+ mask_downsample = IPAdapterMaskProcessor.downsample(
+ mask[:, i, :, :],
+ batch_size,
+ _current_ip_hidden_states.shape[1],
+ _current_ip_hidden_states.shape[2],
+ )
- current_ip_hidden_states = current_ip_hidden_states * mask_downsample
+ mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+ hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+ else:
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ current_ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+ current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+ batch_size, -1, attn.heads * head_dim
+ )
+ current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
- hidden_states = hidden_states + scale * current_ip_hidden_states
+ hidden_states = hidden_states + scale * current_ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
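The rewritten mask handling accepts `ip_adapter_masks` as a list with one entry per IP-Adapter, each shaped `[1, num_images_for_that_adapter, height, width]`, optionally with a per-image list of scales. A hedged sketch of preparing such masks for a single IP-Adapter fed two reference images; `pipe`, `ref_a`, `ref_b` and the final call are assumed, not shown in this diff:

```python
import PIL.Image

from diffusers.image_processor import IPAdapterMaskProcessor

# Two binary masks: left half for the first reference image, right half for the second.
mask_left = PIL.Image.new("L", (512, 512), 0)
mask_left.paste(255, (0, 0, 256, 512))
mask_right = PIL.Image.new("L", (512, 512), 0)
mask_right.paste(255, (256, 0, 512, 512))

processor = IPAdapterMaskProcessor()
masks = processor.preprocess([mask_left, mask_right], height=512, width=512)  # -> (2, 1, 512, 512)
masks = [masks.reshape(1, 2, 512, 512)]  # one IP-Adapter, two masked images

# pipe.set_ip_adapter_scale([[0.7, 0.7]])  # optional: one scale per masked image
# image = pipe(prompt, ip_adapter_image=[[ref_a, ref_b]],
#              cross_attention_kwargs={"ip_adapter_masks": masks}).images[0]
```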
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py
index 9bbf2023eb99..b286453de424 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -453,8 +453,8 @@ def forward(
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
- key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
diff --git a/src/diffusers/models/autoencoders/autoencoder_tiny.py b/src/diffusers/models/autoencoders/autoencoder_tiny.py
index ef43526cf8a0..a7047acdfd74 100644
--- a/src/diffusers/models/autoencoders/autoencoder_tiny.py
+++ b/src/diffusers/models/autoencoders/autoencoder_tiny.py
@@ -102,6 +102,7 @@ def __init__(
encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
act_fn: str = "relu",
+ upsample_fn: str = "nearest",
latent_channels: int = 4,
upsampling_scaling_factor: int = 2,
num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
@@ -133,6 +134,7 @@ def __init__(
block_out_channels=decoder_block_out_channels,
upsampling_scaling_factor=upsampling_scaling_factor,
act_fn=act_fn,
+ upsample_fn=upsample_fn,
)
self.latent_magnitude = latent_magnitude
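The new `upsample_fn` argument is forwarded to the decoder's `nn.Upsample(mode=...)`, as the companion `vae.py` change below shows. A minimal instantiation sketch with an untrained model:

```python
from diffusers import AutoencoderTiny

# "nearest" remains the default; any interpolation mode accepted by torch.nn.Upsample works.
vae = AutoencoderTiny(upsample_fn="bilinear")
```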
diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py
index cf2e37337c16..826a6365b726 100644
--- a/src/diffusers/models/autoencoders/vae.py
+++ b/src/diffusers/models/autoencoders/vae.py
@@ -926,6 +926,7 @@ def __init__(
block_out_channels: Tuple[int, ...],
upsampling_scaling_factor: int,
act_fn: str,
+ upsample_fn: str,
):
super().__init__()
@@ -942,7 +943,7 @@ def __init__(
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
if not is_final_block:
- layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
+ layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor, mode=upsample_fn))
conv_out_channel = num_channels if not is_final_block else out_channels
layers.append(
diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnet_flax.py
index 6f9b201aa1e3..0540850a9e61 100644
--- a/src/diffusers/models/controlnet_flax.py
+++ b/src/diffusers/models/controlnet_flax.py
@@ -329,15 +329,15 @@ def __call__(
controlnet_cond (`jnp.ndarray`): (batch, channel, height, width) the conditional input tensor
conditioning_scale (`float`, *optional*, defaults to `1.0`): the scale factor for controlnet outputs
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
- plain tuple.
+ Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of
+ a plain tuple.
train (`bool`, *optional*, defaults to `False`):
Use deterministic functions and disable dropout when not training.
Returns:
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
- [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
- `tuple`. When returning a tuple, the first element is the sample tensor.
+ [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise
+ a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
channel_order = self.controlnet_conditioning_channel_order
if channel_order == "bgr":
diff --git a/src/diffusers/models/controlnet_xs.py b/src/diffusers/models/controlnet_xs.py
new file mode 100644
index 000000000000..a4f9e61f37c7
--- /dev/null
+++ b/src/diffusers/models/controlnet_xs.py
@@ -0,0 +1,1915 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from math import gcd
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import FloatTensor, nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, is_torch_version, logging
+from ..utils.torch_utils import apply_freeu
+from .attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ Attention,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from .controlnet import ControlNetConditioningEmbedding
+from .embeddings import TimestepEmbedding, Timesteps
+from .modeling_utils import ModelMixin
+from .unets.unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ CrossAttnUpBlock2D,
+ Downsample2D,
+ ResnetBlock2D,
+ Transformer2DModel,
+ UNetMidBlock2DCrossAttn,
+ Upsample2D,
+)
+from .unets.unet_2d_condition import UNet2DConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class ControlNetXSOutput(BaseOutput):
+ """
+ The output of [`UNetControlNetXSModel`].
+
+ Args:
+ sample (`FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base
+ model output, but is already the final output.
+ """
+
+ sample: FloatTensor = None
+
+
+class DownBlockControlNetXSAdapter(nn.Module):
+ """Components that together with corresponding components from the base model will form a
+ `ControlNetXSCrossAttnDownBlock2D`"""
+
+ def __init__(
+ self,
+ resnets: nn.ModuleList,
+ base_to_ctrl: nn.ModuleList,
+ ctrl_to_base: nn.ModuleList,
+ attentions: Optional[nn.ModuleList] = None,
+ downsampler: Optional[nn.Conv2d] = None,
+ ):
+ super().__init__()
+ self.resnets = resnets
+ self.base_to_ctrl = base_to_ctrl
+ self.ctrl_to_base = ctrl_to_base
+ self.attentions = attentions
+ self.downsamplers = downsampler
+
+
+class MidBlockControlNetXSAdapter(nn.Module):
+ """Components that together with corresponding components from the base model will form a
+ `ControlNetXSCrossAttnMidBlock2D`"""
+
+ def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl: nn.ModuleList, ctrl_to_base: nn.ModuleList):
+ super().__init__()
+ self.midblock = midblock
+ self.base_to_ctrl = base_to_ctrl
+ self.ctrl_to_base = ctrl_to_base
+
+
+class UpBlockControlNetXSAdapter(nn.Module):
+ """Components that together with corresponding components from the base model will form a `ControlNetXSCrossAttnUpBlock2D`"""
+
+ def __init__(self, ctrl_to_base: nn.ModuleList):
+ super().__init__()
+ self.ctrl_to_base = ctrl_to_base
+
+
+def get_down_block_adapter(
+ base_in_channels: int,
+ base_out_channels: int,
+ ctrl_in_channels: int,
+ ctrl_out_channels: int,
+ temb_channels: int,
+ max_norm_num_groups: Optional[int] = 32,
+ has_crossattn=True,
+ transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1,
+ num_attention_heads: Optional[int] = 1,
+ cross_attention_dim: Optional[int] = 1024,
+ add_downsample: bool = True,
+ upcast_attention: Optional[bool] = False,
+):
+ num_layers = 2 # only support sd + sdxl
+
+ resnets = []
+ attentions = []
+ ctrl_to_base = []
+ base_to_ctrl = []
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
+ base_in_channels = base_in_channels if i == 0 else base_out_channels
+ ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels
+
+ # Before the resnet/attention application, information is concatted from base to control.
+ # Concat doesn't require change in number of channels
+ base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels))
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl
+ out_channels=ctrl_out_channels,
+ temb_channels=temb_channels,
+ groups=find_largest_factor(ctrl_in_channels + base_in_channels, max_factor=max_norm_num_groups),
+ groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
+ eps=1e-5,
+ )
+ )
+
+ if has_crossattn:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ ctrl_out_channels // num_attention_heads,
+ in_channels=ctrl_out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ use_linear_projection=True,
+ upcast_attention=upcast_attention,
+ norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
+ )
+ )
+
+ # After the resnet/attention application, information is added from control to base
+ # Addition requires change in number of channels
+ ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
+
+ if add_downsample:
+ # Before the downsampler application, information is concatted from base to control
+ # Concat doesn't require change in number of channels
+ base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels))
+
+ downsamplers = Downsample2D(
+ ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op"
+ )
+
+ # After the downsampler application, information is added from control to base
+ # Addition requires change in number of channels
+ ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
+ else:
+ downsamplers = None
+
+ down_block_components = DownBlockControlNetXSAdapter(
+ resnets=nn.ModuleList(resnets),
+ base_to_ctrl=nn.ModuleList(base_to_ctrl),
+ ctrl_to_base=nn.ModuleList(ctrl_to_base),
+ )
+
+ if has_crossattn:
+ down_block_components.attentions = nn.ModuleList(attentions)
+ if downsamplers is not None:
+ down_block_components.downsamplers = downsamplers
+
+ return down_block_components
+
+
+def get_mid_block_adapter(
+ base_channels: int,
+ ctrl_channels: int,
+ temb_channels: Optional[int] = None,
+ max_norm_num_groups: Optional[int] = 32,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: Optional[int] = 1,
+ cross_attention_dim: Optional[int] = 1024,
+ upcast_attention: bool = False,
+):
+ # Before the midblock application, information is concatted from base to control.
+ # Concat doesn't require change in number of channels
+ base_to_ctrl = make_zero_conv(base_channels, base_channels)
+
+ midblock = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=ctrl_channels + base_channels,
+ out_channels=ctrl_channels,
+ temb_channels=temb_channels,
+ # number of norm groups must divide both in_channels and out_channels
+ resnet_groups=find_largest_factor(gcd(ctrl_channels, ctrl_channels + base_channels), max_norm_num_groups),
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ use_linear_projection=True,
+ upcast_attention=upcast_attention,
+ )
+
+ # After the midblock application, information is added from control to base
+ # Addition requires change in number of channels
+ ctrl_to_base = make_zero_conv(ctrl_channels, base_channels)
+
+ return MidBlockControlNetXSAdapter(base_to_ctrl=base_to_ctrl, midblock=midblock, ctrl_to_base=ctrl_to_base)
+
+
+def get_up_block_adapter(
+ out_channels: int,
+ prev_output_channel: int,
+ ctrl_skip_channels: List[int],
+):
+ ctrl_to_base = []
+ num_layers = 3 # only support sd + sdxl
+ for i in range(num_layers):
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+ ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels))
+
+ return UpBlockControlNetXSAdapter(ctrl_to_base=nn.ModuleList(ctrl_to_base))
+
+
+class ControlNetXSAdapter(ModelMixin, ConfigMixin):
+ r"""
+ A `ControlNetXSAdapter` model. To use it, pass it into a `UNetControlNetXSModel` (together with a
+ `UNet2DConditionModel` base model).
+
+ This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for its generic
+ methods implemented for all models (such as downloading or saving).
+
+ Like `UNetControlNetXSModel`, `ControlNetXSAdapter` is compatible with StableDiffusion and StableDiffusion-XL. Its
+ default parameters are compatible with StableDiffusion.
+
+ Parameters:
+ conditioning_channels (`int`, defaults to 3):
+ Number of channels of conditioning input (e.g. an image)
+ conditioning_channel_order (`str`, defaults to `"rgb"`):
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channels for each block in the `controlnet_cond_embedding` layer.
+ time_embedding_mix (`float`, defaults to 1.0):
+ If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time
+ embedding is used. Otherwise, both are combined.
+ learn_time_embedding (`bool`, defaults to `False`):
+ Whether a time embedding should be learned. If yes, `UNetControlNetXSModel` will combine the time
+ embeddings of the base model and the control adapter. If no, `UNetControlNetXSModel` will use the base
+ model's time embedding.
+ num_attention_heads (`list[int]`, defaults to `[4]`):
+ The number of attention heads.
+ block_out_channels (`list[int]`, defaults to `[4, 8, 16, 16]`):
+ The tuple of output channels for each block.
+ base_block_out_channels (`list[int]`, defaults to `[320, 640, 1280, 1280]`):
+ The tuple of output channels for each block in the base unet.
+ cross_attention_dim (`int`, defaults to 1024):
+ The dimension of the cross attention features.
+ down_block_types (`list[str]`, defaults to `["CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"]`):
+ The tuple of downsample blocks to use.
+ sample_size (`int`, defaults to 96):
+ Height and width of input/output sample.
+ transformer_layers_per_block (`Union[int, Tuple[int]]`, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ upcast_attention (`bool`, defaults to `True`):
+ Whether the attention computation should always be upcasted.
+ max_norm_num_groups (`int`, defaults to 32):
+ Maximum number of groups in group normalization. The actual number will be the largest divisor of the
+ respective channels that is <= `max_norm_num_groups`.
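+
+ Example:
+ A minimal usage sketch. The checkpoint id and the top-level imports below are illustrative assumptions,
+ not fixed requirements:
+
+ ```py
+ >>> from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
+
+ >>> unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")
+ >>> # create a control adapter whose channels are 10% of the base UNet's
+ >>> controlnet = ControlNetXSAdapter.from_unet(unet, size_ratio=0.1)
+ >>> # fuse base UNet and adapter into a single model
+ >>> model = UNetControlNetXSModel.from_unet(unet, controlnet)
+ ```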
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ conditioning_channels: int = 3,
+ conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
+ time_embedding_mix: float = 1.0,
+ learn_time_embedding: bool = False,
+ num_attention_heads: Union[int, Tuple[int]] = 4,
+ block_out_channels: Tuple[int] = (4, 8, 16, 16),
+ base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ cross_attention_dim: int = 1024,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ sample_size: Optional[int] = 96,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ upcast_attention: bool = True,
+ max_norm_num_groups: int = 32,
+ ):
+ super().__init__()
+
+ time_embedding_input_dim = base_block_out_channels[0]
+ time_embedding_dim = base_block_out_channels[0] * 4
+
+ # Check inputs
+ if conditioning_channel_order not in ["rgb", "bgr"]:
+ raise ValueError(f"unknown `conditioning_channel_order`: {conditioning_channel_order}")
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(transformer_layers_per_block, (list, tuple)):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+ if not isinstance(cross_attention_dim, (list, tuple)):
+ cross_attention_dim = [cross_attention_dim] * len(down_block_types)
+ # see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why `ControlNetXSAdapter` takes `num_attention_heads` instead of `attention_head_dim`
+ if not isinstance(num_attention_heads, (list, tuple)):
+ num_attention_heads = [num_attention_heads] * len(down_block_types)
+
+ if len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ # Create conditioning hint embedding
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+ conditioning_embedding_channels=block_out_channels[0],
+ block_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ # time
+ if learn_time_embedding:
+ self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embedding_dim)
+ else:
+ self.time_embedding = None
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_connections = nn.ModuleList([])
+
+ # input
+ self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1)
+ self.control_to_base_for_conv_in = make_zero_conv(block_out_channels[0], base_block_out_channels[0])
+
+ # down
+ base_out_channels = base_block_out_channels[0]
+ ctrl_out_channels = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ base_in_channels = base_out_channels
+ base_out_channels = base_block_out_channels[i]
+ ctrl_in_channels = ctrl_out_channels
+ ctrl_out_channels = block_out_channels[i]
+ has_crossattn = "CrossAttn" in down_block_type
+ is_final_block = i == len(down_block_types) - 1
+
+ self.down_blocks.append(
+ get_down_block_adapter(
+ base_in_channels=base_in_channels,
+ base_out_channels=base_out_channels,
+ ctrl_in_channels=ctrl_in_channels,
+ ctrl_out_channels=ctrl_out_channels,
+ temb_channels=time_embedding_dim,
+ max_norm_num_groups=max_norm_num_groups,
+ has_crossattn=has_crossattn,
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ num_attention_heads=num_attention_heads[i],
+ cross_attention_dim=cross_attention_dim[i],
+ add_downsample=not is_final_block,
+ upcast_attention=upcast_attention,
+ )
+ )
+
+ # mid
+ self.mid_block = get_mid_block_adapter(
+ base_channels=base_block_out_channels[-1],
+ ctrl_channels=block_out_channels[-1],
+ temb_channels=time_embedding_dim,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ num_attention_heads=num_attention_heads[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ upcast_attention=upcast_attention,
+ )
+
+ # up
+ # The skip connection channels are the output of the conv_in and of all the down subblocks
+ ctrl_skip_channels = [block_out_channels[0]]
+ for i, out_channels in enumerate(block_out_channels):
+ number_of_subblocks = (
+ 3 if i < len(block_out_channels) - 1 else 2
+ ) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler
+ ctrl_skip_channels.extend([out_channels] * number_of_subblocks)
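+ # e.g. for the default block_out_channels=(4, 8, 16, 16) this gives
+ # ctrl_skip_channels = [4, 4, 4, 4, 8, 8, 8, 16, 16, 16, 16, 16], i.e. 12 skips = 4 up blocks * 3 connections each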
+
+ reversed_base_block_out_channels = list(reversed(base_block_out_channels))
+
+ base_out_channels = reversed_base_block_out_channels[0]
+ for i in range(len(down_block_types)):
+ prev_base_output_channel = base_out_channels
+ base_out_channels = reversed_base_block_out_channels[i]
+ ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)]
+
+ self.up_connections.append(
+ get_up_block_adapter(
+ out_channels=base_out_channels,
+ prev_output_channel=prev_base_output_channel,
+ ctrl_skip_channels=ctrl_skip_channels_,
+ )
+ )
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ size_ratio: Optional[float] = None,
+ block_out_channels: Optional[List[int]] = None,
+ num_attention_heads: Optional[List[int]] = None,
+ learn_time_embedding: bool = False,
+ time_embedding_mix: float = 1.0,
+ conditioning_channels: int = 3,
+ conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
+ ):
+ r"""
+ Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model we want to control. The dimensions of the ControlNetXSAdapter will be adapted to it.
+ size_ratio (float, *optional*, defaults to `None`):
+ When given, block_out_channels is set to a fraction of the base model's block_out_channels. Either this
+ or `block_out_channels` must be given.
+ block_out_channels (`List[int]`, *optional*, defaults to `None`):
+ Down blocks output channels in control model. Either this or `size_ratio` must be given.
+ num_attention_heads (`List[int]`, *optional*, defaults to `None`):
+ The number of attention heads. If not given, `unet.config.attention_head_dim` is used, which (despite
+ its name) stores the number of heads for these UNets; see
+ https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for the naming history.
+ learn_time_embedding (`bool`, defaults to `False`):
+ Whether the `ControlNetXSAdapter` should learn a time embedding.
+ time_embedding_mix (`float`, defaults to 1.0):
+ If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time
+ embedding is used. Otherwise, both are combined.
+ conditioning_channels (`int`, defaults to 3):
+ Number of channels of conditioning input (e.g. an image)
+ conditioning_channel_order (`str`, defaults to `"rgb"`):
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channels for each block in the `controlnet_cond_embedding` layer.
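+
+ Example:
+ A rough sketch showing the two mutually exclusive sizing options; the checkpoint id is an illustrative
+ assumption:
+
+ ```py
+ >>> from diffusers import ControlNetXSAdapter, UNet2DConditionModel
+
+ >>> unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")
+ >>> # relative sizing: the adapter's channels are a fraction of the base UNet's
+ >>> adapter = ControlNetXSAdapter.from_unet(unet, size_ratio=0.1)
+ >>> # or absolute sizing (passing both options raises a ValueError)
+ >>> adapter = ControlNetXSAdapter.from_unet(unet, block_out_channels=[32, 64, 128, 128])
+ ```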
+ """
+
+ # Check input
+ fixed_size = block_out_channels is not None
+ relative_size = size_ratio is not None
+ if not (fixed_size ^ relative_size):
+ raise ValueError(
+ "Pass exactly one of `block_out_channels` (for absolute sizing) or `size_ratio` (for relative sizing)."
+ )
+
+ # Create model
+ block_out_channels = block_out_channels or [int(b * size_ratio) for b in unet.config.block_out_channels]
+ if num_attention_heads is None:
+ # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
+ num_attention_heads = unet.config.attention_head_dim
+
+ model = cls(
+ conditioning_channels=conditioning_channels,
+ conditioning_channel_order=conditioning_channel_order,
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ time_embedding_mix=time_embedding_mix,
+ learn_time_embedding=learn_time_embedding,
+ num_attention_heads=num_attention_heads,
+ block_out_channels=block_out_channels,
+ base_block_out_channels=unet.config.block_out_channels,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ down_block_types=unet.config.down_block_types,
+ sample_size=unet.config.sample_size,
+ transformer_layers_per_block=unet.config.transformer_layers_per_block,
+ upcast_attention=unet.config.upcast_attention,
+ max_norm_num_groups=unet.config.norm_num_groups,
+ )
+
+ # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
+ model.to(unet.dtype)
+
+ return model
+
+ def forward(self, *args, **kwargs):
+ raise ValueError(
+ "A ControlNetXSAdapter cannot be run by itself. Use it together with a UNet2DConditionModel to instantiate a UNetControlNetXSModel."
+ )
+
+
+class UNetControlNetXSModel(ModelMixin, ConfigMixin):
+ r"""
+ A UNet fused with a ControlNet-XS adapter model.
+
+ This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for its generic
+ methods implemented for all models (such as downloading or saving).
+
+ `UNetControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. Its default parameters are
+ compatible with StableDiffusion.
+
+ Its parameters are either passed to the underlying `UNet2DConditionModel` or used exactly as in
+ `ControlNetXSAdapter`. See their documentation for details.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ # unet configs
+ sample_size: Optional[int] = 96,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ norm_num_groups: Optional[int] = 32,
+ cross_attention_dim: Union[int, Tuple[int]] = 1024,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ num_attention_heads: Union[int, Tuple[int]] = 8,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ upcast_attention: bool = True,
+ time_cond_proj_dim: Optional[int] = None,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ # additional controlnet configs
+ time_embedding_mix: float = 1.0,
+ ctrl_conditioning_channels: int = 3,
+ ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
+ ctrl_conditioning_channel_order: str = "rgb",
+ ctrl_learn_time_embedding: bool = False,
+ ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16),
+ ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
+ ctrl_max_norm_num_groups: int = 32,
+ ):
+ super().__init__()
+
+ if time_embedding_mix < 0 or time_embedding_mix > 1:
+ raise ValueError("`time_embedding_mix` needs to be between 0 and 1.")
+ if time_embedding_mix < 1 and not ctrl_learn_time_embedding:
+ raise ValueError("To use `time_embedding_mix` < 1, `ctrl_learn_time_embedding` must be `True`")
+
+ if addition_embed_type is not None and addition_embed_type != "text_time":
+ raise ValueError(
+ "As `UNetControlNetXSModel` currently only supports StableDiffusion and StableDiffusion-XL, `addition_embed_type` must be `None` or `'text_time'`."
+ )
+
+ if not isinstance(transformer_layers_per_block, (list, tuple)):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+ if not isinstance(cross_attention_dim, (list, tuple)):
+ cross_attention_dim = [cross_attention_dim] * len(down_block_types)
+ if not isinstance(num_attention_heads, (list, tuple)):
+ num_attention_heads = [num_attention_heads] * len(down_block_types)
+ if not isinstance(ctrl_num_attention_heads, (list, tuple)):
+ ctrl_num_attention_heads = [ctrl_num_attention_heads] * len(down_block_types)
+
+ base_num_attention_heads = num_attention_heads
+
+ self.in_channels = 4
+
+ # # Input
+ self.base_conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1)
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+ conditioning_embedding_channels=ctrl_block_out_channels[0],
+ block_out_channels=ctrl_conditioning_embedding_out_channels,
+ conditioning_channels=ctrl_conditioning_channels,
+ )
+ self.ctrl_conv_in = nn.Conv2d(4, ctrl_block_out_channels[0], kernel_size=3, padding=1)
+ self.control_to_base_for_conv_in = make_zero_conv(ctrl_block_out_channels[0], block_out_channels[0])
+
+ # # Time
+ time_embed_input_dim = block_out_channels[0]
+ time_embed_dim = block_out_channels[0] * 4
+
+ self.base_time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.base_time_embedding = TimestepEmbedding(
+ time_embed_input_dim,
+ time_embed_dim,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+ self.ctrl_time_embedding = TimestepEmbedding(in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim)
+
+ if addition_embed_type is None:
+ self.base_add_time_proj = None
+ self.base_add_embedding = None
+ else:
+ self.base_add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.base_add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ # # Create down blocks
+ down_blocks = []
+ base_out_channels = block_out_channels[0]
+ ctrl_out_channels = ctrl_block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ base_in_channels = base_out_channels
+ base_out_channels = block_out_channels[i]
+ ctrl_in_channels = ctrl_out_channels
+ ctrl_out_channels = ctrl_block_out_channels[i]
+ has_crossattn = "CrossAttn" in down_block_type
+ is_final_block = i == len(down_block_types) - 1
+
+ down_blocks.append(
+ ControlNetXSCrossAttnDownBlock2D(
+ base_in_channels=base_in_channels,
+ base_out_channels=base_out_channels,
+ ctrl_in_channels=ctrl_in_channels,
+ ctrl_out_channels=ctrl_out_channels,
+ temb_channels=time_embed_dim,
+ norm_num_groups=norm_num_groups,
+ ctrl_max_norm_num_groups=ctrl_max_norm_num_groups,
+ has_crossattn=has_crossattn,
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ base_num_attention_heads=base_num_attention_heads[i],
+ ctrl_num_attention_heads=ctrl_num_attention_heads[i],
+ cross_attention_dim=cross_attention_dim[i],
+ add_downsample=not is_final_block,
+ upcast_attention=upcast_attention,
+ )
+ )
+
+ # # Create mid block
+ self.mid_block = ControlNetXSCrossAttnMidBlock2D(
+ base_channels=block_out_channels[-1],
+ ctrl_channels=ctrl_block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ norm_num_groups=norm_num_groups,
+ ctrl_max_norm_num_groups=ctrl_max_norm_num_groups,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ base_num_attention_heads=base_num_attention_heads[-1],
+ ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ upcast_attention=upcast_attention,
+ )
+
+ # # Create up blocks
+ up_blocks = []
+ rev_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+ rev_num_attention_heads = list(reversed(base_num_attention_heads))
+ rev_cross_attention_dim = list(reversed(cross_attention_dim))
+
+ # The skip connection channels are the output of the conv_in and of all the down subblocks
+ ctrl_skip_channels = [ctrl_block_out_channels[0]]
+ for i, out_channels in enumerate(ctrl_block_out_channels):
+ number_of_subblocks = (
+ 3 if i < len(ctrl_block_out_channels) - 1 else 2
+ ) # every block has 3 subblocks, except last one, which has 2 as it has no downsampler
+ ctrl_skip_channels.extend([out_channels] * number_of_subblocks)
+
+ reversed_block_out_channels = list(reversed(block_out_channels))
+
+ out_channels = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = out_channels
+ out_channels = reversed_block_out_channels[i]
+ in_channels = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+ ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)]
+
+ has_crossattn = "CrossAttn" in up_block_type
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_blocks.append(
+ ControlNetXSCrossAttnUpBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ ctrl_skip_channels=ctrl_skip_channels_,
+ temb_channels=time_embed_dim,
+ resolution_idx=i,
+ has_crossattn=has_crossattn,
+ transformer_layers_per_block=rev_transformer_layers_per_block[i],
+ num_attention_heads=rev_num_attention_heads[i],
+ cross_attention_dim=rev_cross_attention_dim[i],
+ add_upsample=not is_final_block,
+ upcast_attention=upcast_attention,
+ norm_num_groups=norm_num_groups,
+ )
+ )
+
+ self.down_blocks = nn.ModuleList(down_blocks)
+ self.up_blocks = nn.ModuleList(up_blocks)
+
+ self.base_conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups)
+ self.base_conv_act = nn.SiLU()
+ self.base_conv_out = nn.Conv2d(block_out_channels[0], 4, kernel_size=3, padding=1)
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ controlnet: Optional[ControlNetXSAdapter] = None,
+ size_ratio: Optional[float] = None,
+ ctrl_block_out_channels: Optional[List[int]] = None,
+ time_embedding_mix: Optional[float] = None,
+ ctrl_optional_kwargs: Optional[Dict] = None,
+ ):
+ r"""
+ Instantiate a [`UNetControlNetXSModel`] from a [`UNet2DConditionModel`] and an optional [`ControlNetXSAdapter`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model we want to control.
+ controlnet (`ControlNetXSAdapter`, *optional*):
+ The ControlNet-XS adapter with which the UNet will be fused. If none is given, a new ControlNet-XS
+ adapter will be created.
+ size_ratio (float, *optional*, defaults to `None`):
+ Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
+ ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`):
+ Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
+ where this parameter is called `block_out_channels`.
+ time_embedding_mix (`float`, *optional*, defaults to None):
+ Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
+ ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`):
+ Passed to the `__init__` of the new controlnet if no controlnet was given.
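+
+ Example:
+ A minimal fusion sketch; the checkpoint id is an illustrative assumption:
+
+ ```py
+ >>> from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
+
+ >>> unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")
+ >>> controlnet = ControlNetXSAdapter.from_unet(unet, size_ratio=0.1)
+ >>> # either fuse an existing adapter ...
+ >>> model = UNetControlNetXSModel.from_unet(unet, controlnet)
+ >>> # ... or let a new adapter be created on the fly
+ >>> model = UNetControlNetXSModel.from_unet(unet, size_ratio=0.1)
+ ```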
+ """
+ if controlnet is None:
+ ctrl_optional_kwargs = dict(ctrl_optional_kwargs) if ctrl_optional_kwargs else {}
+ if time_embedding_mix is not None:
+ ctrl_optional_kwargs["time_embedding_mix"] = time_embedding_mix
+ controlnet = ControlNetXSAdapter.from_unet(
+ unet, size_ratio, ctrl_block_out_channels, **ctrl_optional_kwargs
+ )
+ else:
+ if any(
+ o is not None for o in (size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs)
+ ):
+ raise ValueError(
+ "When a controlnet is passed, none of these parameters should be passed: size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs."
+ )
+
+ # # get params
+ params_for_unet = [
+ "sample_size",
+ "down_block_types",
+ "up_block_types",
+ "block_out_channels",
+ "norm_num_groups",
+ "cross_attention_dim",
+ "transformer_layers_per_block",
+ "addition_embed_type",
+ "addition_time_embed_dim",
+ "upcast_attention",
+ "time_cond_proj_dim",
+ "projection_class_embeddings_input_dim",
+ ]
+ params_for_unet = {k: v for k, v in unet.config.items() if k in params_for_unet}
+ # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
+ params_for_unet["num_attention_heads"] = unet.config.attention_head_dim
+
+ params_for_controlnet = [
+ "conditioning_channels",
+ "conditioning_embedding_out_channels",
+ "conditioning_channel_order",
+ "learn_time_embedding",
+ "block_out_channels",
+ "num_attention_heads",
+ "max_norm_num_groups",
+ ]
+ params_for_controlnet = {"ctrl_" + k: v for k, v in controlnet.config.items() if k in params_for_controlnet}
+ params_for_controlnet["time_embedding_mix"] = controlnet.config.time_embedding_mix
+
+ # # create model
+ model = cls.from_config({**params_for_unet, **params_for_controlnet})
+
+ # # load weights
+ # from unet
+ modules_from_unet = [
+ "time_embedding",
+ "conv_in",
+ "conv_norm_out",
+ "conv_out",
+ ]
+ for m in modules_from_unet:
+ getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())
+
+ optional_modules_from_unet = [
+ "add_time_proj",
+ "add_embedding",
+ ]
+ for m in optional_modules_from_unet:
+ if hasattr(unet, m) and getattr(unet, m) is not None:
+ getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())
+
+ # from controlnet
+ model.controlnet_cond_embedding.load_state_dict(controlnet.controlnet_cond_embedding.state_dict())
+ model.ctrl_conv_in.load_state_dict(controlnet.conv_in.state_dict())
+ if controlnet.time_embedding is not None:
+ model.ctrl_time_embedding.load_state_dict(controlnet.time_embedding.state_dict())
+ model.control_to_base_for_conv_in.load_state_dict(controlnet.control_to_base_for_conv_in.state_dict())
+
+ # from both
+ model.down_blocks = nn.ModuleList(
+ ControlNetXSCrossAttnDownBlock2D.from_modules(b, c)
+ for b, c in zip(unet.down_blocks, controlnet.down_blocks)
+ )
+ model.mid_block = ControlNetXSCrossAttnMidBlock2D.from_modules(unet.mid_block, controlnet.mid_block)
+ model.up_blocks = nn.ModuleList(
+ ControlNetXSCrossAttnUpBlock2D.from_modules(b, c)
+ for b, c in zip(unet.up_blocks, controlnet.up_connections)
+ )
+
+ # ensure that the UNetControlNetXSModel is the same dtype as the UNet2DConditionModel
+ model.to(unet.dtype)
+
+ return model
+
+ def freeze_unet_params(self) -> None:
+ """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
+ tuning."""
+ # Unfreeze everything
+ for param in self.parameters():
+ param.requires_grad = True
+
+ # Freeze base part
+ base_parts = [
+ "base_time_proj",
+ "base_time_embedding",
+ "base_add_time_proj",
+ "base_add_embedding",
+ "base_conv_in",
+ "base_conv_norm_out",
+ "base_conv_act",
+ "base_conv_out",
+ ]
+ base_parts = [getattr(self, part) for part in base_parts if getattr(self, part) is not None]
+ for part in base_parts:
+ for param in part.parameters():
+ param.requires_grad = False
+
+ for d in self.down_blocks:
+ d.freeze_base_params()
+ self.mid_block.freeze_base_params()
+ for u in self.up_blocks:
+ u.freeze_base_params()
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
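+
+ Example:
+ A hedged sketch; the values below are combinations reported in the official FreeU repository for
+ SD2.1-style models and may need tuning for other checkpoints:
+
+ ```py
+ >>> model.enable_freeu(s1=0.9, s2=0.2, b1=1.4, b2=1.6)
+ >>> # ... run inference ...
+ >>> model.disable_freeu()
+ ```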
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ setattr(upsample_block, k, None)
+
+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def forward(
+ self,
+ sample: FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: Optional[torch.Tensor] = None,
+ conditioning_scale: Optional[float] = 1.0,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ return_dict: bool = True,
+ apply_control: bool = True,
+ ) -> Union[ControlNetXSOutput, Tuple]:
+ """
+ The [`UNetControlNetXSModel`] forward method.
+
+ Args:
+ sample (`FloatTensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+ The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+ controlnet_cond (`FloatTensor`):
+ The conditioning input tensor (e.g. an image) of shape `(batch_size, conditioning_channels, height, width)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ How much the control model affects the base model outputs.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
+ embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+ added_cond_kwargs (`dict`):
+ Additional conditions for the Stable Diffusion XL UNet.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return a [`~models.controlnetxs.ControlNetXSOutput`] instead of a plain tuple.
+ apply_control (`bool`, defaults to `True`):
+ If `False`, the input is run only through the base model.
+
+ Returns:
+ [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
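+
+ Example:
+ A shape-only sketch under the default (SD-2.1-like) configuration; all tensors are random
+ placeholders and the sizes are illustrative:
+
+ ```py
+ >>> import torch
+
+ >>> model = UNetControlNetXSModel()
+ >>> out = model(
+ ... sample=torch.randn(1, 4, 64, 64),
+ ... timestep=10,
+ ... encoder_hidden_states=torch.randn(1, 77, 1024),
+ ... controlnet_cond=torch.randn(1, 3, 512, 512), # conditioning image at 8x the latent resolution
+ ... ).sample
+ >>> out.shape
+ torch.Size([1, 4, 64, 64])
+ ```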
+ """
+
+ # check channel order
+ if self.config.ctrl_conditioning_channel_order == "bgr":
+ controlnet_cond = torch.flip(controlnet_cond, dims=[1])
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.base_time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ if self.config.ctrl_learn_time_embedding and apply_control:
+ ctrl_temb = self.ctrl_time_embedding(t_emb, timestep_cond)
+ base_temb = self.base_time_embedding(t_emb, timestep_cond)
+ interpolation_param = self.config.time_embedding_mix**0.3
+
+ temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
+ else:
+ temb = self.base_time_embedding(t_emb)
+
+ # added time & text embeddings
+ aug_emb = None
+
+ if self.config.addition_embed_type is None:
+ pass
+ elif self.config.addition_embed_type == "text_time":
+ # SDXL - style
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.base_add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(temb.dtype)
+ aug_emb = self.base_add_embedding(add_embeds)
+ else:
+ raise ValueError(
+ f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.config.addition_embed_type} is currently not supported."
+ )
+
+ temb = temb + aug_emb if aug_emb is not None else temb
+
+ # text embeddings
+ cemb = encoder_hidden_states
+
+ # Preparation
+ h_ctrl = h_base = sample
+ hs_base, hs_ctrl = [], []
+
+ # Cross Control
+ guided_hint = self.controlnet_cond_embedding(controlnet_cond)
+
+ # 1 - conv in & down
+
+ h_base = self.base_conv_in(h_base)
+ h_ctrl = self.ctrl_conv_in(h_ctrl)
+ if guided_hint is not None:
+ h_ctrl += guided_hint
+ if apply_control:
+ h_base = h_base + self.control_to_base_for_conv_in(h_ctrl) * conditioning_scale # add ctrl -> base
+
+ hs_base.append(h_base)
+ hs_ctrl.append(h_ctrl)
+
+ for down in self.down_blocks:
+ h_base, h_ctrl, residual_hb, residual_hc = down(
+ hidden_states_base=h_base,
+ hidden_states_ctrl=h_ctrl,
+ temb=temb,
+ encoder_hidden_states=cemb,
+ conditioning_scale=conditioning_scale,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ apply_control=apply_control,
+ )
+ hs_base.extend(residual_hb)
+ hs_ctrl.extend(residual_hc)
+
+ # 2 - mid
+ h_base, h_ctrl = self.mid_block(
+ hidden_states_base=h_base,
+ hidden_states_ctrl=h_ctrl,
+ temb=temb,
+ encoder_hidden_states=cemb,
+ conditioning_scale=conditioning_scale,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ apply_control=apply_control,
+ )
+
+ # 3 - up
+ for up in self.up_blocks:
+ n_resnets = len(up.resnets)
+ skips_hb = hs_base[-n_resnets:]
+ skips_hc = hs_ctrl[-n_resnets:]
+ hs_base = hs_base[:-n_resnets]
+ hs_ctrl = hs_ctrl[:-n_resnets]
+ h_base = up(
+ hidden_states=h_base,
+ res_hidden_states_tuple_base=skips_hb,
+ res_hidden_states_tuple_ctrl=skips_hc,
+ temb=temb,
+ encoder_hidden_states=cemb,
+ conditioning_scale=conditioning_scale,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ apply_control=apply_control,
+ )
+
+ # 4 - conv out
+ h_base = self.base_conv_norm_out(h_base)
+ h_base = self.base_conv_act(h_base)
+ h_base = self.base_conv_out(h_base)
+
+ if not return_dict:
+ return (h_base,)
+
+ return ControlNetXSOutput(sample=h_base)
+
+
+class ControlNetXSCrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ base_in_channels: int,
+ base_out_channels: int,
+ ctrl_in_channels: int,
+ ctrl_out_channels: int,
+ temb_channels: int,
+ norm_num_groups: int = 32,
+ ctrl_max_norm_num_groups: int = 32,
+ has_crossattn=True,
+ transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1,
+ base_num_attention_heads: Optional[int] = 1,
+ ctrl_num_attention_heads: Optional[int] = 1,
+ cross_attention_dim: Optional[int] = 1024,
+ add_downsample: bool = True,
+ upcast_attention: Optional[bool] = False,
+ ):
+ super().__init__()
+ base_resnets = []
+ base_attentions = []
+ ctrl_resnets = []
+ ctrl_attentions = []
+ ctrl_to_base = []
+ base_to_ctrl = []
+
+ num_layers = 2 # only support sd + sdxl
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
+ base_in_channels = base_in_channels if i == 0 else base_out_channels
+ ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels
+
+ # Before the resnet/attention application, information is concatted from base to control.
+ # Concat doesn't require change in number of channels
+ base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels))
+
+ base_resnets.append(
+ ResnetBlock2D(
+ in_channels=base_in_channels,
+ out_channels=base_out_channels,
+ temb_channels=temb_channels,
+ groups=norm_num_groups,
+ )
+ )
+ ctrl_resnets.append(
+ ResnetBlock2D(
+ in_channels=ctrl_in_channels + base_in_channels, # information from base is concatted to ctrl
+ out_channels=ctrl_out_channels,
+ temb_channels=temb_channels,
+ groups=find_largest_factor(
+ ctrl_in_channels + base_in_channels, max_factor=ctrl_max_norm_num_groups
+ ),
+ groups_out=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
+ eps=1e-5,
+ )
+ )
+
+ if has_crossattn:
+ base_attentions.append(
+ Transformer2DModel(
+ base_num_attention_heads,
+ base_out_channels // base_num_attention_heads,
+ in_channels=base_out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ use_linear_projection=True,
+ upcast_attention=upcast_attention,
+ norm_num_groups=norm_num_groups,
+ )
+ )
+ ctrl_attentions.append(
+ Transformer2DModel(
+ ctrl_num_attention_heads,
+ ctrl_out_channels // ctrl_num_attention_heads,
+ in_channels=ctrl_out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ use_linear_projection=True,
+ upcast_attention=upcast_attention,
+ norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
+ )
+ )
+
+ # After the resnet/attention application, information is added from control to base
+ # Addition requires change in number of channels
+ ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
+
+ if add_downsample:
+ # Before the downsampler application, information is concatted from base to control
+ # Concat doesn't require change in number of channels
+ base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels))
+
+ self.base_downsamplers = Downsample2D(
+ base_out_channels, use_conv=True, out_channels=base_out_channels, name="op"
+ )
+ self.ctrl_downsamplers = Downsample2D(
+ ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op"
+ )
+
+ # After the downsampler application, information is added from control to base
+ # Addition requires change in number of channels
+ ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
+ else:
+ self.base_downsamplers = None
+ self.ctrl_downsamplers = None
+
+ self.base_resnets = nn.ModuleList(base_resnets)
+ self.ctrl_resnets = nn.ModuleList(ctrl_resnets)
+ self.base_attentions = nn.ModuleList(base_attentions) if has_crossattn else [None] * num_layers
+ self.ctrl_attentions = nn.ModuleList(ctrl_attentions) if has_crossattn else [None] * num_layers
+ self.base_to_ctrl = nn.ModuleList(base_to_ctrl)
+ self.ctrl_to_base = nn.ModuleList(ctrl_to_base)
+
+ self.gradient_checkpointing = False
+
+ @classmethod
+ def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblock: DownBlockControlNetXSAdapter):
+ # get params
+ def get_first_cross_attention(block):
+ return block.attentions[0].transformer_blocks[0].attn2
+
+ base_in_channels = base_downblock.resnets[0].in_channels
+ base_out_channels = base_downblock.resnets[0].out_channels
+ ctrl_in_channels = (
+ ctrl_downblock.resnets[0].in_channels - base_in_channels
+ ) # base channels are concatted to ctrl channels in init
+ ctrl_out_channels = ctrl_downblock.resnets[0].out_channels
+ temb_channels = base_downblock.resnets[0].time_emb_proj.in_features
+ num_groups = base_downblock.resnets[0].norm1.num_groups
+ ctrl_num_groups = ctrl_downblock.resnets[0].norm1.num_groups
+ if hasattr(base_downblock, "attentions"):
+ has_crossattn = True
+ transformer_layers_per_block = len(base_downblock.attentions[0].transformer_blocks)
+ base_num_attention_heads = get_first_cross_attention(base_downblock).heads
+ ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
+ cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
+ upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
+ else:
+ has_crossattn = False
+ transformer_layers_per_block = None
+ base_num_attention_heads = None
+ ctrl_num_attention_heads = None
+ cross_attention_dim = None
+ upcast_attention = None
+ add_downsample = base_downblock.downsamplers is not None
+
+ # create model
+ model = cls(
+ base_in_channels=base_in_channels,
+ base_out_channels=base_out_channels,
+ ctrl_in_channels=ctrl_in_channels,
+ ctrl_out_channels=ctrl_out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=num_groups,
+ ctrl_max_norm_num_groups=ctrl_num_groups,
+ has_crossattn=has_crossattn,
+ transformer_layers_per_block=transformer_layers_per_block,
+ base_num_attention_heads=base_num_attention_heads,
+ ctrl_num_attention_heads=ctrl_num_attention_heads,
+ cross_attention_dim=cross_attention_dim,
+ add_downsample=add_downsample,
+ upcast_attention=upcast_attention,
+ )
+
+ # # load weights
+ model.base_resnets.load_state_dict(base_downblock.resnets.state_dict())
+ model.ctrl_resnets.load_state_dict(ctrl_downblock.resnets.state_dict())
+ if has_crossattn:
+ model.base_attentions.load_state_dict(base_downblock.attentions.state_dict())
+ model.ctrl_attentions.load_state_dict(ctrl_downblock.attentions.state_dict())
+ if add_downsample:
+ model.base_downsamplers.load_state_dict(base_downblock.downsamplers[0].state_dict())
+ model.ctrl_downsamplers.load_state_dict(ctrl_downblock.downsamplers.state_dict())
+ model.base_to_ctrl.load_state_dict(ctrl_downblock.base_to_ctrl.state_dict())
+ model.ctrl_to_base.load_state_dict(ctrl_downblock.ctrl_to_base.state_dict())
+
+ return model
+
+ def freeze_base_params(self) -> None:
+ """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
+ tuning."""
+ # Unfreeze everything
+ for param in self.parameters():
+ param.requires_grad = True
+
+ # Freeze base part
+ base_parts = [self.base_resnets]
+ if isinstance(self.base_attentions, nn.ModuleList): # attentions can be a list of Nones
+ base_parts.append(self.base_attentions)
+ if self.base_downsamplers is not None:
+ base_parts.append(self.base_downsamplers)
+ for part in base_parts:
+ for param in part.parameters():
+ param.requires_grad = False
+
+ def forward(
+ self,
+ hidden_states_base: FloatTensor,
+ temb: FloatTensor,
+ encoder_hidden_states: Optional[FloatTensor] = None,
+ hidden_states_ctrl: Optional[FloatTensor] = None,
+ conditioning_scale: Optional[float] = 1.0,
+ attention_mask: Optional[FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[FloatTensor] = None,
+ apply_control: bool = True,
+ ) -> Tuple[FloatTensor, FloatTensor, Tuple[FloatTensor, ...], Tuple[FloatTensor, ...]]:
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ h_base = hidden_states_base
+ h_ctrl = hidden_states_ctrl
+
+ base_output_states = ()
+ ctrl_output_states = ()
+
+ base_blocks = list(zip(self.base_resnets, self.base_attentions))
+ ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions))
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip(
+ base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base
+ ):
+ # concat base -> ctrl
+ if apply_control:
+ h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
+
+ # apply base subblock
+ if self.training and self.gradient_checkpointing:
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ h_base = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(b_res),
+ h_base,
+ temb,
+ **ckpt_kwargs,
+ )
+ else:
+ h_base = b_res(h_base, temb)
+
+ if b_attn is not None:
+ h_base = b_attn(
+ h_base,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ # apply ctrl subblock
+ if apply_control:
+ if self.training and self.gradient_checkpointing:
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ h_ctrl = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(c_res),
+ h_ctrl,
+ temb,
+ **ckpt_kwargs,
+ )
+ else:
+ h_ctrl = c_res(h_ctrl, temb)
+ if c_attn is not None:
+ h_ctrl = c_attn(
+ h_ctrl,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ # add ctrl -> base
+ if apply_control:
+ h_base = h_base + c2b(h_ctrl) * conditioning_scale
+
+ base_output_states = base_output_states + (h_base,)
+ ctrl_output_states = ctrl_output_states + (h_ctrl,)
+
+ if self.base_downsamplers is not None: # if we have a base_downsampler, then also a ctrl_downsampler
+ b2c = self.base_to_ctrl[-1]
+ c2b = self.ctrl_to_base[-1]
+
+ # concat base -> ctrl
+ if apply_control:
+ h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
+ # apply base subblock
+ h_base = self.base_downsamplers(h_base)
+ # apply ctrl subblock
+ if apply_control:
+ h_ctrl = self.ctrl_downsamplers(h_ctrl)
+ # add ctrl -> base
+ if apply_control:
+ h_base = h_base + c2b(h_ctrl) * conditioning_scale
+
+ base_output_states = base_output_states + (h_base,)
+ ctrl_output_states = ctrl_output_states + (h_ctrl,)
+
+ return h_base, h_ctrl, base_output_states, ctrl_output_states
+
+
+class ControlNetXSCrossAttnMidBlock2D(nn.Module):
+ def __init__(
+ self,
+ base_channels: int,
+ ctrl_channels: int,
+ temb_channels: Optional[int] = None,
+ norm_num_groups: int = 32,
+ ctrl_max_norm_num_groups: int = 32,
+ transformer_layers_per_block: int = 1,
+ base_num_attention_heads: Optional[int] = 1,
+ ctrl_num_attention_heads: Optional[int] = 1,
+ cross_attention_dim: Optional[int] = 1024,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+
+ # Before the midblock application, information is concatted from base to control.
+ # Concat doesn't require change in number of channels
+ self.base_to_ctrl = make_zero_conv(base_channels, base_channels)
+
+ self.base_midblock = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=base_channels,
+ temb_channels=temb_channels,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=base_num_attention_heads,
+ use_linear_projection=True,
+ upcast_attention=upcast_attention,
+ )
+
+ self.ctrl_midblock = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=ctrl_channels + base_channels,
+ out_channels=ctrl_channels,
+ temb_channels=temb_channels,
+ # number of norm groups must divide both in_channels and out_channels
+ resnet_groups=find_largest_factor(
+ gcd(ctrl_channels, ctrl_channels + base_channels), ctrl_max_norm_num_groups
+ ),
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=ctrl_num_attention_heads,
+ use_linear_projection=True,
+ upcast_attention=upcast_attention,
+ )
+
+ # After the midblock application, information is added from control to base
+ # Addition requires change in number of channels
+ self.ctrl_to_base = make_zero_conv(ctrl_channels, base_channels)
+
+ self.gradient_checkpointing = False
+
+ @classmethod
+ def from_modules(
+ cls,
+ base_midblock: UNetMidBlock2DCrossAttn,
+ ctrl_midblock: MidBlockControlNetXSAdapter,
+ ):
+ base_to_ctrl = ctrl_midblock.base_to_ctrl
+ ctrl_to_base = ctrl_midblock.ctrl_to_base
+ ctrl_midblock = ctrl_midblock.midblock
+
+ # get params
+ def get_first_cross_attention(midblock):
+ return midblock.attentions[0].transformer_blocks[0].attn2
+
+ base_channels = ctrl_to_base.out_channels
+ ctrl_channels = ctrl_to_base.in_channels
+ transformer_layers_per_block = len(base_midblock.attentions[0].transformer_blocks)
+ temb_channels = base_midblock.resnets[0].time_emb_proj.in_features
+ num_groups = base_midblock.resnets[0].norm1.num_groups
+ ctrl_num_groups = ctrl_midblock.resnets[0].norm1.num_groups
+ base_num_attention_heads = get_first_cross_attention(base_midblock).heads
+ ctrl_num_attention_heads = get_first_cross_attention(ctrl_midblock).heads
+ cross_attention_dim = get_first_cross_attention(base_midblock).cross_attention_dim
+ upcast_attention = get_first_cross_attention(base_midblock).upcast_attention
+
+ # create model
+ model = cls(
+ base_channels=base_channels,
+ ctrl_channels=ctrl_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=num_groups,
+ ctrl_max_norm_num_groups=ctrl_num_groups,
+ transformer_layers_per_block=transformer_layers_per_block,
+ base_num_attention_heads=base_num_attention_heads,
+ ctrl_num_attention_heads=ctrl_num_attention_heads,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ )
+
+ # load weights
+ model.base_to_ctrl.load_state_dict(base_to_ctrl.state_dict())
+ model.base_midblock.load_state_dict(base_midblock.state_dict())
+ model.ctrl_midblock.load_state_dict(ctrl_midblock.state_dict())
+ model.ctrl_to_base.load_state_dict(ctrl_to_base.state_dict())
+
+ return model
+
+ def freeze_base_params(self) -> None:
+ """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
+ tuning."""
+ # Unfreeze everything
+ for param in self.parameters():
+ param.requires_grad = True
+
+ # Freeze base part
+ for param in self.base_midblock.parameters():
+ param.requires_grad = False
+
+ def forward(
+ self,
+ hidden_states_base: FloatTensor,
+ temb: FloatTensor,
+ encoder_hidden_states: FloatTensor,
+ hidden_states_ctrl: Optional[FloatTensor] = None,
+ conditioning_scale: Optional[float] = 1.0,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ attention_mask: Optional[FloatTensor] = None,
+ encoder_attention_mask: Optional[FloatTensor] = None,
+ apply_control: bool = True,
+ ) -> Tuple[FloatTensor, FloatTensor]:
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ h_base = hidden_states_base
+ h_ctrl = hidden_states_ctrl
+
+ joint_args = {
+ "temb": temb,
+ "encoder_hidden_states": encoder_hidden_states,
+ "attention_mask": attention_mask,
+ "cross_attention_kwargs": cross_attention_kwargs,
+ "encoder_attention_mask": encoder_attention_mask,
+ }
+
+ if apply_control:
+ h_ctrl = torch.cat([h_ctrl, self.base_to_ctrl(h_base)], dim=1) # concat base -> ctrl
+ h_base = self.base_midblock(h_base, **joint_args) # apply base mid block
+ if apply_control:
+ h_ctrl = self.ctrl_midblock(h_ctrl, **joint_args) # apply ctrl mid block
+ h_base = h_base + self.ctrl_to_base(h_ctrl) * conditioning_scale # add ctrl -> base
+
+ return h_base, h_ctrl
+
+
+class ControlNetXSCrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ ctrl_skip_channels: List[int],
+ temb_channels: int,
+ norm_num_groups: int = 32,
+ resolution_idx: Optional[int] = None,
+ has_crossattn=True,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1024,
+ add_upsample: bool = True,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ ctrl_to_base = []
+
+ num_layers = 3 # only SD and SDXL are supported
+
+ self.has_cross_attention = has_crossattn
+ self.num_attention_heads = num_attention_heads
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels))
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ groups=norm_num_groups,
+ )
+ )
+
+ if has_crossattn:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ use_linear_projection=True,
+ upcast_attention=upcast_attention,
+ norm_num_groups=norm_num_groups,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions) if has_crossattn else [None] * num_layers
+ self.ctrl_to_base = nn.ModuleList(ctrl_to_base)
+
+ if add_upsample:
+ self.upsamplers = Upsample2D(out_channels, use_conv=True, out_channels=out_channels)
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ @classmethod
+ def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_upblock: UpBlockControlNetXSAdapter):
+ ctrl_to_base_skip_connections = ctrl_upblock.ctrl_to_base
+
+ # get params
+ def get_first_cross_attention(block):
+ return block.attentions[0].transformer_blocks[0].attn2
+
+ out_channels = base_upblock.resnets[0].out_channels
+ in_channels = base_upblock.resnets[-1].in_channels - out_channels
+ prev_output_channels = base_upblock.resnets[0].in_channels - out_channels
+ ctrl_skip_channelss = [c.in_channels for c in ctrl_to_base_skip_connections]
+ temb_channels = base_upblock.resnets[0].time_emb_proj.in_features
+ num_groups = base_upblock.resnets[0].norm1.num_groups
+ resolution_idx = base_upblock.resolution_idx
+ if hasattr(base_upblock, "attentions"):
+ has_crossattn = True
+ transformer_layers_per_block = len(base_upblock.attentions[0].transformer_blocks)
+ num_attention_heads = get_first_cross_attention(base_upblock).heads
+ cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
+ upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
+ else:
+ has_crossattn = False
+ transformer_layers_per_block = None
+ num_attention_heads = None
+ cross_attention_dim = None
+ upcast_attention = None
+ add_upsample = base_upblock.upsamplers is not None
+
+ # create model
+ model = cls(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channels,
+ ctrl_skip_channels=ctrl_skip_channelss,
+ temb_channels=temb_channels,
+ norm_num_groups=num_groups,
+ resolution_idx=resolution_idx,
+ has_crossattn=has_crossattn,
+ transformer_layers_per_block=transformer_layers_per_block,
+ num_attention_heads=num_attention_heads,
+ cross_attention_dim=cross_attention_dim,
+ add_upsample=add_upsample,
+ upcast_attention=upcast_attention,
+ )
+
+ # load weights
+ model.resnets.load_state_dict(base_upblock.resnets.state_dict())
+ if has_crossattn:
+ model.attentions.load_state_dict(base_upblock.attentions.state_dict())
+ if add_upsample:
+ model.upsamplers.load_state_dict(base_upblock.upsamplers[0].state_dict())
+ model.ctrl_to_base.load_state_dict(ctrl_to_base_skip_connections.state_dict())
+
+ return model
+
+ def freeze_base_params(self) -> None:
+ """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
+ tuning."""
+ # Unfreeze everything
+ for param in self.parameters():
+ param.requires_grad = True
+
+ # Freeze base part
+ base_parts = [self.resnets]
+ if isinstance(self.attentions, nn.ModuleList): # attentions can be a list of Nones
+ base_parts.append(self.attentions)
+ if self.upsamplers is not None:
+ base_parts.append(self.upsamplers)
+ for part in base_parts:
+ for param in part.parameters():
+ param.requires_grad = False
+
+ def forward(
+ self,
+ hidden_states: FloatTensor,
+ res_hidden_states_tuple_base: Tuple[FloatTensor, ...],
+ res_hidden_states_tuple_ctrl: Tuple[FloatTensor, ...],
+ temb: FloatTensor,
+ encoder_hidden_states: Optional[FloatTensor] = None,
+ conditioning_scale: Optional[float] = 1.0,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ attention_mask: Optional[FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ encoder_attention_mask: Optional[FloatTensor] = None,
+ apply_control: bool = True,
+ ) -> FloatTensor:
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ return apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_h_base,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+ else:
+ return hidden_states, res_h_base
+
+ for resnet, attn, c2b, res_h_base, res_h_ctrl in zip(
+ self.resnets,
+ self.attentions,
+ self.ctrl_to_base,
+ reversed(res_hidden_states_tuple_base),
+ reversed(res_hidden_states_tuple_ctrl),
+ ):
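+ # add the control skip activation to the base hidden states (ctrl -> base)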
+ if apply_control:
+ hidden_states += c2b(res_h_ctrl) * conditioning_scale
+
+ hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
+ hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if attn is not None:
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if self.upsamplers is not None:
+ hidden_states = self.upsamplers(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+def make_zero_conv(in_channels, out_channels=None):
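+ # 1x1 convolution whose parameters are initialized to zero (see `zero_module` below),
+ # so the connection initially contributes nothing to the features it is added to.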
+ return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))
+
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
+
+
+def find_largest_factor(number, max_factor):
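+ # Returns the largest factor of `number` that is not larger than `max_factor`.
+ # Used above to pick a valid number of norm groups, which must divide the channel count.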
+ factor = max_factor
+ if factor >= number:
+ return number
+ while factor != 0:
+ residual = number % factor
+ if residual == 0:
+ return factor
+ factor -= 1
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 85b1e4944ed2..ced520bb8204 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -472,6 +472,22 @@ def forward(self, image_embeds: torch.FloatTensor):
return self.norm(self.ff(image_embeds))
+class IPAdapterFaceIDImageProjection(nn.Module):
+ def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
+ super().__init__()
+ from .attention import FeedForward
+
+ self.num_tokens = num_tokens
+ self.cross_attention_dim = cross_attention_dim
+ self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
+ self.norm = nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds: torch.FloatTensor):
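+ # project the image embedding into `num_tokens` tokens of width `cross_attention_dim`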
+ x = self.ff(image_embeds)
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+ return self.norm(x)
+
+
class CombinedTimestepLabelEmbeddings(nn.Module):
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
super().__init__()
@@ -794,17 +810,15 @@ class IPAdapterPlusImageProjection(nn.Module):
"""Resampler of IP-Adapter Plus.
Args:
- ----
- embed_dims (int): The feature dimension. Defaults to 768.
- output_dims (int): The number of output channels, that is the same
- number of the channels in the
- `unet.config.cross_attention_dim`. Defaults to 1024.
- hidden_dims (int): The number of hidden channels. Defaults to 1280.
- depth (int): The number of blocks. Defaults to 8.
- dim_head (int): The number of head channels. Defaults to 64.
- heads (int): Parallel attention heads. Defaults to 16.
- num_queries (int): The number of queries. Defaults to 8.
- ffn_ratio (float): The expansion ratio of feedforward network hidden
+ embed_dims (int): The feature dimension. Defaults to 768.
+ output_dims (int): The number of output channels, that is the same number of the channels in the
+ `unet.config.cross_attention_dim`. Defaults to 1024.
+ hidden_dims (int): The number of hidden channels. Defaults to 1280.
+ depth (int): The number of blocks. Defaults to 8.
+ dim_head (int): The number of head channels. Defaults to 64.
+ heads (int): Parallel attention heads. Defaults to 16.
+ num_queries (int): The number of queries. Defaults to 8.
+ ffn_ratio (float): The expansion ratio of feedforward network hidden
layer channels. Defaults to 4.
"""
@@ -854,11 +868,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass.
Args:
- ----
x (torch.Tensor): Input Tensor.
-
Returns:
- -------
torch.Tensor: Output Tensor.
"""
latents = self.latents.repeat(x.size(0), 1, 1)
@@ -878,6 +889,119 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.norm_out(latents)
+class IPAdapterPlusImageProjectionBlock(nn.Module):
+ def __init__(
+ self,
+ embed_dims: int = 768,
+ dim_head: int = 64,
+ heads: int = 16,
+ ffn_ratio: float = 4,
+ ) -> None:
+ super().__init__()
+ from .attention import FeedForward
+
+ self.ln0 = nn.LayerNorm(embed_dims)
+ self.ln1 = nn.LayerNorm(embed_dims)
+ self.attn = Attention(
+ query_dim=embed_dims,
+ dim_head=dim_head,
+ heads=heads,
+ out_bias=False,
+ )
+ self.ff = nn.Sequential(
+ nn.LayerNorm(embed_dims),
+ FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+ )
+
+ def forward(self, x, latents, residual):
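+ # the latents attend over the concatenation of the image features and the latents themselves
+ # (Perceiver-style resampling, as used by IP-Adapter Plus)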
+ encoder_hidden_states = self.ln0(x)
+ latents = self.ln1(latents)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+ latents = self.attn(latents, encoder_hidden_states) + residual
+ latents = self.ff(latents) + latents
+ return latents
+
+
+class IPAdapterFaceIDPlusImageProjection(nn.Module):
+ """FacePerceiverResampler of IP-Adapter Plus.
+
+ Args:
+ embed_dims (int): The feature dimension. Defaults to 768.
+ output_dims (int): The number of output channels, that is the same number of the channels in the
+ `unet.config.cross_attention_dim`. Defaults to 1024.
+ hidden_dims (int): The number of hidden channels. Defaults to 1280.
+ depth (int): The number of blocks. Defaults to 8.
+ dim_head (int): The number of head channels. Defaults to 64.
+ heads (int): Parallel attention heads. Defaults to 16.
+ num_tokens (int): Number of tokens.
+ num_queries (int): The number of queries. Defaults to 8.
+ ffn_ratio (float): The expansion ratio of feedforward network hidden layer channels. Defaults to 4.
+ ffproj_ratio (float): The expansion ratio of feedforward network hidden layer channels (for ID embeddings).
+ Defaults to 4.
+ """
+
+ def __init__(
+ self,
+ embed_dims: int = 768,
+ output_dims: int = 768,
+ hidden_dims: int = 1280,
+ id_embeddings_dim: int = 512,
+ depth: int = 4,
+ dim_head: int = 64,
+ heads: int = 16,
+ num_tokens: int = 4,
+ num_queries: int = 8,
+ ffn_ratio: float = 4,
+ ffproj_ratio: int = 2,
+ ) -> None:
+ super().__init__()
+ from .attention import FeedForward
+
+ self.num_tokens = num_tokens
+ self.embed_dim = embed_dims
+ self.clip_embeds = None
+ self.shortcut = False
+ self.shortcut_scale = 1.0
+
+ self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
+ self.norm = nn.LayerNorm(embed_dims)
+
+ self.proj_in = nn.Linear(hidden_dims, embed_dims)
+
+ self.proj_out = nn.Linear(embed_dims, output_dims)
+ self.norm_out = nn.LayerNorm(output_dims)
+
+ self.layers = nn.ModuleList(
+ [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+ )
+
+ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
+ """Forward pass.
+
+ Args:
+ id_embeds (torch.Tensor): Input Tensor (ID embeds).
+ Returns:
+ torch.Tensor: Output Tensor.
+ """
+ id_embeds = id_embeds.to(self.clip_embeds.dtype)
+ id_embeds = self.proj(id_embeds)
+ id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
+ id_embeds = self.norm(id_embeds)
+ latents = id_embeds
+
+ clip_embeds = self.proj_in(self.clip_embeds)
+ x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])
+
+ for block in self.layers:
+ residual = latents
+ latents = block(x, latents, residual)
+
+ latents = self.proj_out(latents)
+ out = self.norm_out(latents)
+ if self.shortcut:
+ out = id_embeds + self.shortcut_scale * out
+ return out
+
+
class MultiIPAdapterImageProjection(nn.Module):
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
super().__init__()
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 95ae6f4fc4ea..c1fdff8ab356 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -699,6 +699,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
+ force_hooks=True,
+ strict=True,
)
except AttributeError as e:
# When using accelerate loading, we do not have the ability to load the state
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 88c7a01be6bf..adda53a11481 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -202,8 +202,8 @@ class ResnetBlock2D(nn.Module):
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
- By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift"
- for a stronger conditioning with scale and shift.
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a
+ stronger conditioning with scale and shift.
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
diff --git a/src/diffusers/models/transformers/dual_transformer_2d.py b/src/diffusers/models/transformers/dual_transformer_2d.py
index 96849bd28bb1..e2f1b8538ca0 100644
--- a/src/diffusers/models/transformers/dual_transformer_2d.py
+++ b/src/diffusers/models/transformers/dual_transformer_2d.py
@@ -120,7 +120,8 @@ def forward(
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+ Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
Returns:
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index 0658a7daa241..296449889e36 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -102,6 +102,8 @@ def __init__(
interpolation_scale: float = None,
):
super().__init__()
+
+ # Validate inputs.
if patch_size is not None:
if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
raise NotImplementedError(
@@ -112,10 +114,16 @@ def __init__(
f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
)
+ # Set some common variables used across the board.
self.use_linear_projection = use_linear_projection
+ self.interpolation_scale = interpolation_scale
+ self.caption_channels = caption_channels
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
- inner_dim = num_attention_heads * attention_head_dim
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.gradient_checkpointing = False
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
@@ -150,104 +158,167 @@ def __init__(
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
)
- # 2. Define input layers
+ # 2. Initialize the right blocks.
+ # These functions follow a common structure:
+ # a. Initialize the input blocks.
+ # b. Initialize the transformer blocks.
+ # c. Initialize the output blocks and other projection blocks when necessary.
if self.is_input_continuous:
- self.in_channels = in_channels
-
- self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
- if use_linear_projection:
- self.proj_in = nn.Linear(in_channels, inner_dim)
- else:
- self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ self._init_continuous_input(norm_type=norm_type)
elif self.is_input_vectorized:
- assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
- assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+ self._init_vectorized_inputs(norm_type=norm_type)
+ elif self.is_input_patches:
+ self._init_patched_inputs(norm_type=norm_type)
- self.height = sample_size
- self.width = sample_size
- self.num_vector_embeds = num_vector_embeds
- self.num_latent_pixels = self.height * self.width
+ def _init_continuous_input(self, norm_type):
+ self.norm = torch.nn.GroupNorm(
+ num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
+ )
+ if self.use_linear_projection:
+ self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim)
+ else:
+ self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0)
- self.latent_image_embedding = ImagePositionalEmbeddings(
- num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
- )
- elif self.is_input_patches:
- assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ self.inner_dim,
+ self.config.num_attention_heads,
+ self.config.attention_head_dim,
+ dropout=self.config.dropout,
+ cross_attention_dim=self.config.cross_attention_dim,
+ activation_fn=self.config.activation_fn,
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+ attention_bias=self.config.attention_bias,
+ only_cross_attention=self.config.only_cross_attention,
+ double_self_attention=self.config.double_self_attention,
+ upcast_attention=self.config.upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
+ norm_eps=self.config.norm_eps,
+ attention_type=self.config.attention_type,
+ )
+ for _ in range(self.config.num_layers)
+ ]
+ )
- self.height = sample_size
- self.width = sample_size
+ if self.use_linear_projection:
+ self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels)
+ else:
+ self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0)
- self.patch_size = patch_size
- interpolation_scale = (
- interpolation_scale if interpolation_scale is not None else max(self.config.sample_size // 64, 1)
- )
- self.pos_embed = PatchEmbed(
- height=sample_size,
- width=sample_size,
- patch_size=patch_size,
- in_channels=in_channels,
- embed_dim=inner_dim,
- interpolation_scale=interpolation_scale,
- )
+ def _init_vectorized_inputs(self, norm_type):
+ assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+ assert (
+ self.config.num_vector_embeds is not None
+ ), "Transformer2DModel over discrete input must provide num_embed"
+
+ self.height = self.config.sample_size
+ self.width = self.config.sample_size
+ self.num_latent_pixels = self.height * self.width
+
+ self.latent_image_embedding = ImagePositionalEmbeddings(
+ num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width
+ )
- # 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
- inner_dim,
- num_attention_heads,
- attention_head_dim,
- dropout=dropout,
- cross_attention_dim=cross_attention_dim,
- activation_fn=activation_fn,
- num_embeds_ada_norm=num_embeds_ada_norm,
- attention_bias=attention_bias,
- only_cross_attention=only_cross_attention,
- double_self_attention=double_self_attention,
- upcast_attention=upcast_attention,
+ self.inner_dim,
+ self.config.num_attention_heads,
+ self.config.attention_head_dim,
+ dropout=self.config.dropout,
+ cross_attention_dim=self.config.cross_attention_dim,
+ activation_fn=self.config.activation_fn,
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+ attention_bias=self.config.attention_bias,
+ only_cross_attention=self.config.only_cross_attention,
+ double_self_attention=self.config.double_self_attention,
+ upcast_attention=self.config.upcast_attention,
norm_type=norm_type,
- norm_elementwise_affine=norm_elementwise_affine,
- norm_eps=norm_eps,
- attention_type=attention_type,
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
+ norm_eps=self.config.norm_eps,
+ attention_type=self.config.attention_type,
)
- for d in range(num_layers)
+ for _ in range(self.config.num_layers)
]
)
- # 4. Define output layers
- self.out_channels = in_channels if out_channels is None else out_channels
- if self.is_input_continuous:
- # TODO: should use out_channels for continuous projections
- if use_linear_projection:
- self.proj_out = nn.Linear(inner_dim, in_channels)
- else:
- self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
- elif self.is_input_vectorized:
- self.norm_out = nn.LayerNorm(inner_dim)
- self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
- elif self.is_input_patches and norm_type != "ada_norm_single":
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
- self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
- self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
- elif self.is_input_patches and norm_type == "ada_norm_single":
- self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
- self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
- self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-
- # 5. PixArt-Alpha blocks.
+ self.norm_out = nn.LayerNorm(self.inner_dim)
+ self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1)
+
+ def _init_patched_inputs(self, norm_type):
+ assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+ self.height = self.config.sample_size
+ self.width = self.config.sample_size
+
+ self.patch_size = self.config.patch_size
+ interpolation_scale = (
+ self.config.interpolation_scale
+ if self.config.interpolation_scale is not None
+ else max(self.config.sample_size // 64, 1)
+ )
+ self.pos_embed = PatchEmbed(
+ height=self.config.sample_size,
+ width=self.config.sample_size,
+ patch_size=self.config.patch_size,
+ in_channels=self.in_channels,
+ embed_dim=self.inner_dim,
+ interpolation_scale=interpolation_scale,
+ )
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ self.inner_dim,
+ self.config.num_attention_heads,
+ self.config.attention_head_dim,
+ dropout=self.config.dropout,
+ cross_attention_dim=self.config.cross_attention_dim,
+ activation_fn=self.config.activation_fn,
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+ attention_bias=self.config.attention_bias,
+ only_cross_attention=self.config.only_cross_attention,
+ double_self_attention=self.config.double_self_attention,
+ upcast_attention=self.config.upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
+ norm_eps=self.config.norm_eps,
+ attention_type=self.config.attention_type,
+ )
+ for _ in range(self.config.num_layers)
+ ]
+ )
+
+ if self.config.norm_type != "ada_norm_single":
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
+ self.proj_out_2 = nn.Linear(
+ self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+ )
+ elif self.config.norm_type == "ada_norm_single":
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+ self.proj_out = nn.Linear(
+ self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+ )
+
+ # PixArt-Alpha blocks.
self.adaln_single = None
self.use_additional_conditions = False
- if norm_type == "ada_norm_single":
+ if self.config.norm_type == "ada_norm_single":
self.use_additional_conditions = self.config.sample_size == 128
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
# additional conditions until we find better name
- self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
+ self.adaln_single = AdaLayerNormSingle(
+ self.inner_dim, use_additional_conditions=self.use_additional_conditions
+ )
self.caption_projection = None
- if caption_channels is not None:
- self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
-
- self.gradient_checkpointing = False
+ if self.caption_channels is not None:
+ self.caption_projection = PixArtAlphaTextProjection(
+ in_features=self.caption_channels, hidden_size=self.inner_dim
+ )
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
@@ -331,41 +402,18 @@ def forward(
# 1. Input
if self.is_input_continuous:
- batch, _, height, width = hidden_states.shape
+ batch_size, _, height, width = hidden_states.shape
residual = hidden_states
-
- hidden_states = self.norm(hidden_states)
- if not self.use_linear_projection:
- hidden_states = self.proj_in(hidden_states)
- inner_dim = hidden_states.shape[1]
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
- else:
- inner_dim = hidden_states.shape[1]
- hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
- hidden_states = self.proj_in(hidden_states)
-
+ hidden_states, inner_dim = self._operate_on_continuous_inputs(hidden_states)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
- hidden_states = self.pos_embed(hidden_states)
-
- if self.adaln_single is not None:
- if self.use_additional_conditions and added_cond_kwargs is None:
- raise ValueError(
- "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
- )
- batch_size = hidden_states.shape[0]
- timestep, embedded_timestep = self.adaln_single(
- timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
- )
+ hidden_states, encoder_hidden_states, timestep, embedded_timestep = self._operate_on_patched_inputs(
+ hidden_states, encoder_hidden_states, timestep, added_cond_kwargs
+ )
# 2. Blocks
- if self.caption_projection is not None:
- batch_size = hidden_states.shape[0]
- encoder_hidden_states = self.caption_projection(encoder_hidden_states)
- encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing:
@@ -403,51 +451,116 @@ def custom_forward(*inputs):
# 3. Output
if self.is_input_continuous:
- if not self.use_linear_projection:
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
- hidden_states = self.proj_out(hidden_states)
- else:
- hidden_states = self.proj_out(hidden_states)
- hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-
- output = hidden_states + residual
+ output = self._get_output_for_continuous_inputs(
+ hidden_states=hidden_states,
+ residual=residual,
+ batch_size=batch_size,
+ height=height,
+ width=width,
+ inner_dim=inner_dim,
+ )
elif self.is_input_vectorized:
- hidden_states = self.norm_out(hidden_states)
- logits = self.out(hidden_states)
- # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
- logits = logits.permute(0, 2, 1)
+ output = self._get_output_for_vectorized_inputs(hidden_states)
+ elif self.is_input_patches:
+ output = self._get_output_for_patched_inputs(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ class_labels=class_labels,
+ embedded_timestep=embedded_timestep,
+ height=height,
+ width=width,
+ )
- # log(p(x_0))
- output = F.log_softmax(logits.double(), dim=1).float()
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
- if self.is_input_patches:
- if self.config.norm_type != "ada_norm_single":
- conditioning = self.transformer_blocks[0].norm1.emb(
- timestep, class_labels, hidden_dtype=hidden_states.dtype
+ def _operate_on_continuous_inputs(self, hidden_states):
+ batch, _, height, width = hidden_states.shape
+ hidden_states = self.norm(hidden_states)
+
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ return hidden_states, inner_dim
+
+ def _operate_on_patched_inputs(self, hidden_states, encoder_hidden_states, timestep, added_cond_kwargs):
+ batch_size = hidden_states.shape[0]
+ hidden_states = self.pos_embed(hidden_states)
+ embedded_timestep = None
+
+ if self.adaln_single is not None:
+ if self.use_additional_conditions and added_cond_kwargs is None:
+ raise ValueError(
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
)
- shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
- hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
- hidden_states = self.proj_out_2(hidden_states)
- elif self.config.norm_type == "ada_norm_single":
- shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
- hidden_states = self.norm_out(hidden_states)
- # Modulation
- hidden_states = hidden_states * (1 + scale) + shift
- hidden_states = self.proj_out(hidden_states)
- hidden_states = hidden_states.squeeze(1)
-
- # unpatchify
- if self.adaln_single is None:
- height = width = int(hidden_states.shape[1] ** 0.5)
- hidden_states = hidden_states.reshape(
- shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+ timestep, embedded_timestep = self.adaln_single(
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
- hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
- output = hidden_states.reshape(
- shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+
+ if self.caption_projection is not None:
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+ return hidden_states, encoder_hidden_states, timestep, embedded_timestep
+
+ def _get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim):
+ if not self.use_linear_projection:
+ hidden_states = (
+ hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states.reshape(batch_size, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
)
- if not return_dict:
- return (output,)
+ output = hidden_states + residual
+ return output
- return Transformer2DModelOutput(sample=output)
+ def _get_output_for_vectorized_inputs(self, hidden_states):
+ hidden_states = self.norm_out(hidden_states)
+ logits = self.out(hidden_states)
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+ logits = logits.permute(0, 2, 1)
+ # log(p(x_0))
+ output = F.log_softmax(logits.double(), dim=1).float()
+ return output
+
+ def _get_output_for_patched_inputs(
+ self, hidden_states, timestep, class_labels, embedded_timestep, height=None, width=None
+ ):
+ if self.config.norm_type != "ada_norm_single":
+ conditioning = self.transformer_blocks[0].norm1.emb(
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+ hidden_states = self.proj_out_2(hidden_states)
+ elif self.config.norm_type == "ada_norm_single":
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states)
+ # Modulation
+ hidden_states = hidden_states * (1 + scale) + shift
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.squeeze(1)
+
+ # unpatchify
+ if self.adaln_single is None:
+ height = width = int(hidden_states.shape[1] ** 0.5)
+ hidden_states = hidden_states.reshape(
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+ )
+ return output
diff --git a/src/diffusers/models/transformers/transformer_temporal.py b/src/diffusers/models/transformers/transformer_temporal.py
index a35aa4671e6c..c2d490f3d046 100644
--- a/src/diffusers/models/transformers/transformer_temporal.py
+++ b/src/diffusers/models/transformers/transformer_temporal.py
@@ -294,8 +294,8 @@ def forward(
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
images, 0 indicates that the input contains video frames.
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
- tuple.
+ Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a
+ plain tuple.
Returns:
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py
index d54630376961..ef75fad25e44 100644
--- a/src/diffusers/models/unets/unet_2d_blocks.py
+++ b/src/diffusers/models/unets/unet_2d_blocks.py
@@ -746,6 +746,7 @@ def __init__(
self,
in_channels: int,
temb_channels: int,
+ out_channels: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -753,6 +754,7 @@ def __init__(
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
+ resnet_groups_out: Optional[int] = None,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
@@ -764,6 +766,10 @@ def __init__(
):
super().__init__()
+ out_channels = out_channels or in_channels
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -772,14 +778,17 @@ def __init__(
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+ resnet_groups_out = resnet_groups_out or resnet_groups
+
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
- out_channels=in_channels,
+ out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
+ groups_out=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
@@ -794,11 +803,11 @@ def __init__(
attentions.append(
Transformer2DModel(
num_attention_heads,
- in_channels // num_attention_heads,
- in_channels=in_channels,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
+ norm_num_groups=resnet_groups_out,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
@@ -808,8 +817,8 @@ def __init__(
attentions.append(
DualTransformer2DModel(
num_attention_heads,
- in_channels // num_attention_heads,
- in_channels=in_channels,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
@@ -817,11 +826,11 @@ def __init__(
)
resnets.append(
ResnetBlock2D(
- in_channels=in_channels,
- out_channels=in_channels,
+ in_channels=out_channels,
+ out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
- groups=resnet_groups,
+ groups=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 9a710919d067..34327e1049c5 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -865,8 +865,8 @@ def disable_freeu(self):
def fuse_qkv_projections(self):
"""
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
- key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
@@ -1093,8 +1093,8 @@ def forward(
Returns:
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
- If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
- a `tuple` is returned where the first element is the sample tensor.
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
diff --git a/src/diffusers/models/unets/unet_2d_condition_flax.py b/src/diffusers/models/unets/unet_2d_condition_flax.py
index a5ec2875ca0e..edbbcbaeda73 100644
--- a/src/diffusers/models/unets/unet_2d_condition_flax.py
+++ b/src/diffusers/models/unets/unet_2d_condition_flax.py
@@ -76,7 +76,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
- Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer is skipped.
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer
+ is skipped.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
@@ -350,15 +351,15 @@ def __call__(
mid_block_additional_residual: (`torch.Tensor`, *optional*):
A tensor that if specified is added to the residual of the middle unet block.
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
- plain tuple.
+ Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of
+ a plain tuple.
train (`bool`, *optional*, defaults to `False`):
Use deterministic functions and disable dropout when not training.
Returns:
[`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
- [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
- When returning a tuple, the first element is the sample tensor.
+ [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is the sample tensor.
"""
# 1. time
if not isinstance(timesteps, jnp.ndarray):
diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py
index a827b4ddc5a7..6c353c425911 100644
--- a/src/diffusers/models/unets/unet_3d_condition.py
+++ b/src/diffusers/models/unets/unet_3d_condition.py
@@ -511,8 +511,8 @@ def disable_freeu(self):
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
- key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index 5c5c6a2cc5ec..0a5f71ed0029 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -99,8 +99,8 @@ def forward(
class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
- I2VGenXL UNet. It is a conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep
- and returns a sample-shaped output.
+ I2VGenXL UNet. It is a conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and
+ returns a sample-shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
@@ -477,8 +477,8 @@ def disable_freeu(self):
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
- key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
@@ -533,7 +533,8 @@ def forward(
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
fps (`torch.Tensor`): Frames per second for the video being generated. Used as a "micro-condition".
image_latents (`torch.FloatTensor`): Image encodings from the VAE.
- image_embeddings (`torch.FloatTensor`): Projection embeddings of the conditioning image computed with a vision encoder.
+ image_embeddings (`torch.FloatTensor`):
+ Projection embeddings of the conditioning image computed with a vision encoder.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
cross_attention_kwargs (`dict`, *optional*):
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index 88c0b967c099..595b7b03571c 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -709,8 +709,8 @@ def disable_freeu(self) -> None:
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
- key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py
index 5fe265e63fc5..0f89df8c6bff 100644
--- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py
+++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py
@@ -31,8 +31,8 @@ class UNetSpatioTemporalConditionOutput(BaseOutput):
class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
- A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample
- shaped output.
+ A conditional Spatio-Temporal UNet model that takes noisy video frames, conditional state, and a timestep and
+ returns a sample-shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
@@ -57,7 +57,8 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
The dimension of the cross attention features.
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
- [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
+ [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
+ [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
[`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`):
The number of attention heads.
@@ -374,12 +375,12 @@ def forward(
The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
embeddings and added to the time embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
- tuple.
+ Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead
+ of a plain tuple.
Returns:
[`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
- If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
- a `tuple` is returned where the first element is the sample tensor.
+ If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
+ returned, otherwise a `tuple` is returned where the first element is the sample tensor.
"""
# 1. time
timesteps = timestep
diff --git a/src/diffusers/models/unets/unet_stable_cascade.py b/src/diffusers/models/unets/unet_stable_cascade.py
index 6227f7413a3c..ff76415ecf0e 100644
--- a/src/diffusers/models/unets/unet_stable_cascade.py
+++ b/src/diffusers/models/unets/unet_stable_cascade.py
@@ -186,7 +186,8 @@ def __init__(
block_out_channels (Tuple[int], defaults to (2048, 2048)):
Tuple of output channels for each block.
num_attention_heads (Tuple[int], defaults to (32, 32)):
- Number of attention heads in each attention block. Set to -1 to if block types in a layer do not have attention.
+ Number of attention heads in each attention block. Set to -1 if block types in a layer do not have
+ attention.
down_num_layers_per_block (Tuple[int], defaults to [8, 24]):
Number of layers in each down block.
up_num_layers_per_block (Tuple[int], defaults to [24, 8]):
@@ -197,10 +198,9 @@ def __init__(
Number of 1x1 Convolutional layers to repeat in each up block.
block_types_per_layer (Tuple[Tuple[str]], optional,
defaults to (
- ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
- ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock")
- ):
- Block types used in each layer of the up/down blocks.
+ ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), ("SDCascadeResBlock",
+ "SDCascadeTimestepBlock", "SDCascadeAttnBlock")
+ ): Block types used in each layer of the up/down blocks.
clip_text_in_channels (`int`, *optional*, defaults to `None`):
Number of input channels for CLIP based text conditioning.
clip_text_pooled_in_channels (`int`, *optional*, defaults to 1280):
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 2b2277809b38..ab7c13b56eb8 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -134,6 +134,12 @@
"StableDiffusionXLControlNetPipeline",
]
)
+ _import_structure["controlnet_xs"].extend(
+ [
+ "StableDiffusionControlNetXSPipeline",
+ "StableDiffusionXLControlNetXSPipeline",
+ ]
+ )
_import_structure["deepfloyd_if"] = [
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
@@ -378,6 +384,10 @@
StableDiffusionXLControlNetInpaintPipeline,
StableDiffusionXLControlNetPipeline,
)
+ from .controlnet_xs import (
+ StableDiffusionControlNetXSPipeline,
+ StableDiffusionXLControlNetXSPipeline,
+ )
from .deepfloyd_if import (
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
diff --git a/src/diffusers/pipelines/amused/pipeline_amused.py b/src/diffusers/pipelines/amused/pipeline_amused.py
index aa682b46fe70..994455ff29db 100644
--- a/src/diffusers/pipelines/amused/pipeline_amused.py
+++ b/src/diffusers/pipelines/amused/pipeline_amused.py
@@ -30,9 +30,7 @@
>>> import torch
>>> from diffusers import AmusedPipeline
- >>> pipe = AmusedPipeline.from_pretrained(
- ... "amused/amused-512", variant="fp16", torch_dtype=torch.float16
- ... )
+ >>> pipe = AmusedPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
>>> prompt = "a photo of an astronaut riding a horse on mars"
@@ -150,10 +148,12 @@ def __call__(
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
- The targeted aesthetic score according to the laion aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/
- and the micro-conditioning section of https://arxiv.org/abs/2307.01952.
+ The targeted aesthetic score according to the laion aesthetic classifier. See
+ https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
+ https://arxiv.org/abs/2307.01952.
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
- The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952.
+ The targeted height, width crop coordinates. See the micro-conditioning section of
+ https://arxiv.org/abs/2307.01952.
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
diff --git a/src/diffusers/pipelines/amused/pipeline_amused_img2img.py b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py
index 8b49d1a64578..1218e7a44c4d 100644
--- a/src/diffusers/pipelines/amused/pipeline_amused_img2img.py
+++ b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py
@@ -167,10 +167,12 @@ def __call__(
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
- The targeted aesthetic score according to the laion aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/
- and the micro-conditioning section of https://arxiv.org/abs/2307.01952.
+ The targeted aesthetic score according to the laion aesthetic classifier. See
+ https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
+ https://arxiv.org/abs/2307.01952.
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
- The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952.
+ The targeted height, width crop coordinates. See the micro-conditioning section of
+ https://arxiv.org/abs/2307.01952.
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
diff --git a/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
index 423f5734b478..ab0a55cdd388 100644
--- a/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
+++ b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
@@ -191,10 +191,12 @@ def __call__(
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6):
- The targeted aesthetic score according to the laion aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/
- and the micro-conditioning section of https://arxiv.org/abs/2307.01952.
+ The targeted aesthetic score according to the laion aesthetic classifier. See
+ https://laion.ai/blog/laion-aesthetics/ and the micro-conditioning section of
+ https://arxiv.org/abs/2307.01952.
micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)):
- The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952.
+ The targeted height, width crop coordinates. See the micro-conditioning section of
+ https://arxiv.org/abs/2307.01952.
temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)):
Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`.
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index 12347227a15e..3765db938cd5 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -639,10 +639,10 @@ def __call__(
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
index 43d334439532..106fabba721b 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
@@ -52,14 +52,21 @@
>>> from io import BytesIO
>>> from PIL import Image
- >>> adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
- >>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained("SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter).to("cuda")
- >>> pipe.scheduler = DDIMScheduler(beta_schedule="linear", steps_offset=1, clip_sample=False, timespace_spacing="linspace")
+ >>> adapter = MotionAdapter.from_pretrained(
+ ... "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
+ ... )
+ >>> pipe = AnimateDiffVideoToVideoPipeline.from_pretrained(
+ ... "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter
+ ... ).to("cuda")
+ >>> pipe.scheduler = DDIMScheduler(
+ ... beta_schedule="linear", steps_offset=1, clip_sample=False, timestep_spacing="linspace"
+ ... )
+
>>> def load_video(file_path: str):
... images = []
- ...
- ... if file_path.startswith(('http://', 'https://')):
+
+ ... if file_path.startswith(("http://", "https://")):
... # If the file_path is a URL
... response = requests.get(file_path)
... response.raise_for_status()
@@ -68,15 +75,20 @@
... else:
... # Assuming it's a local file path
... vid = imageio.get_reader(file_path)
- ...
+
... for frame in vid:
... pil_image = Image.fromarray(frame)
... images.append(pil_image)
- ...
+
... return images
- >>> video = load_video("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif")
- >>> output = pipe(video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5)
+
+ >>> video = load_video(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-vid2vid-input-1.gif"
+ ... )
+ >>> output = pipe(
+ ... video=video, prompt="panda playing a guitar, on a boat, in the ocean, high quality", strength=0.5
+ ... )
>>> frames = output.frames[0]
>>> export_to_gif(frames, "animation.gif")
```
@@ -135,8 +147,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
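For reference, `retrieve_timesteps` takes either `num_inference_steps` or an explicit `timesteps` list, never both, and returns the resolved `(timesteps, num_inference_steps)` pair. A sketch of the two call patterns; the second form only works for schedulers whose `set_timesteps` accepts a `timesteps` argument, so it is guarded here:

```py
from diffusers import DDIMScheduler
from diffusers.pipelines.animatediff.pipeline_animatediff_video2video import retrieve_timesteps

scheduler = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

# Usual path: the scheduler spreads 25 steps over its trained range.
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=25, device="cpu")
print(num_inference_steps, timesteps[:3])

# Custom-timestep path: pass an explicit decreasing list and leave `num_inference_steps` unset.
# Schedulers without custom-timestep support (DDIM at the time of this diff) raise a ValueError.
try:
    timesteps, num_inference_steps = retrieve_timesteps(scheduler, timesteps=[999, 749, 499, 249, 0])
except ValueError as err:
    print(err)
```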
@@ -626,7 +638,7 @@ def prepare_latents(
# video must be a list of list of images
# the outer list denotes having multiple videos as input, whereas inner list means the frames of the video
# as a list of images
- if not isinstance(video[0], list):
+ if video and not isinstance(video[0], list):
video = [video]
if latents is None:
video = torch.cat(
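The `video and not isinstance(video[0], list)` guard above normalizes a single video (a flat list of frames) into the batched nested-list form without tripping on an empty list. A standalone sketch of the same normalization, with hypothetical placeholder frames:

```py
from typing import List

import PIL.Image


def normalize_video_input(video: List) -> List[List["PIL.Image.Image"]]:
    # A flat list of frames means a single video; wrap it so the outer list indexes
    # videos and the inner list indexes frames. An empty list is left untouched.
    if video and not isinstance(video[0], list):
        video = [video]
    return video


frames = [PIL.Image.new("RGB", (64, 64)) for _ in range(4)]  # hypothetical 4-frame clip
assert normalize_video_input(frames) == [frames]
assert normalize_video_input([frames]) == [frames]
assert normalize_video_input([]) == []
```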
@@ -799,16 +811,15 @@ def __call__(
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`AnimateDiffPipelineOutput`] instead
- of a plain tuple.
+ Whether or not to return a [`AnimateDiffPipelineOutput`] instead of a plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
diff --git a/src/diffusers/pipelines/animatediff/pipeline_output.py b/src/diffusers/pipelines/animatediff/pipeline_output.py
index 184a45848a37..97e7c87ad7f7 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_output.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_output.py
@@ -15,7 +15,8 @@ class AnimateDiffPipelineOutput(BaseOutput):
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
- List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing denoised
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised
PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`
"""
diff --git a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
index 69bebdd0dc4f..78b730ea916c 100644
--- a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
+++ b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
@@ -330,8 +330,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, devic
shape = (
batch_size,
num_channels_latents,
- height // self.vae_scale_factor,
- self.vocoder.config.model_in_dim // self.vae_scale_factor,
+ int(height) // self.vae_scale_factor,
+ int(self.vocoder.config.model_in_dim) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
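The `int(...)` casts above make the latent shape robust to `height` arriving as a float (for instance if a caller computes it from `audio_length_in_s`; that caller behavior is an assumption, the cast is a general safeguard). Floor-dividing a float keeps a float, and `torch.randn` rejects float dimensions. A tiny illustration outside the pipeline:

```py
import torch

vae_scale_factor = 8
height = 256.0  # can arrive as a float when derived from a duration in seconds

print(height // vae_scale_factor)       # 32.0 -- a float, unusable as a tensor dimension
print(int(height) // vae_scale_factor)  # 32   -- what the cast above guarantees

latents = torch.randn(1, 8, int(height) // vae_scale_factor, 16)
print(latents.shape)
```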
diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
index c0b85e4db5f6..948caf97d27b 100644
--- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
@@ -95,7 +95,14 @@ class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin):
"""
@register_to_config
- def __init__(self, text_encoder_dim, text_encoder_1_dim, langauge_model_dim):
+ def __init__(
+ self,
+ text_encoder_dim,
+ text_encoder_1_dim,
+ langauge_model_dim,
+ use_learned_position_embedding=None,
+ max_seq_length=None,
+ ):
super().__init__()
# additional projection layers for each text encoder
self.projection = nn.Linear(text_encoder_dim, langauge_model_dim)
@@ -108,6 +115,14 @@ def __init__(self, text_encoder_dim, text_encoder_1_dim, langauge_model_dim):
self.sos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))
self.eos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))
+ self.use_learned_position_embedding = use_learned_position_embedding
+
+ # learnable positional embedding for the Vits encoder
+ if self.use_learned_position_embedding is not None:
+ self.learnable_positional_embedding = torch.nn.Parameter(
+ torch.zeros((1, text_encoder_1_dim, max_seq_length))
+ )
+
def forward(
self,
hidden_states: Optional[torch.FloatTensor] = None,
@@ -120,6 +135,10 @@ def forward(
hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed
)
+ # Add positional embedding for Vits hidden state
+ if self.use_learned_position_embedding is not None:
+ hidden_states_1 = (hidden_states_1.permute(0, 2, 1) + self.learnable_positional_embedding).permute(0, 2, 1)
+
hidden_states_1 = self.projection_1(hidden_states_1)
hidden_states_1, attention_mask_1 = add_special_tokens(
hidden_states_1, attention_mask_1, sos_token=self.sos_embed_1, eos_token=self.eos_embed_1
@@ -701,8 +720,8 @@ def forward(
Returns:
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
- If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
- a `tuple` is returned where the first element is the sample tensor.
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
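The learned positional embedding registered above has shape `(1, text_encoder_1_dim, max_seq_length)`, while the Vits hidden-states arrive as `(batch, seq_len, dim)`, hence the permute-add-permute in the forward pass. A standalone shape check with hypothetical sizes:

```py
import torch
import torch.nn as nn

batch, seq_len, dim = 2, 64, 192  # hypothetical Vits hidden-state sizes
max_seq_length = 64

# Parameter layout mirrors the diff: (1, dim, max_seq_length).
learnable_positional_embedding = nn.Parameter(torch.zeros(1, dim, max_seq_length))

hidden_states_1 = torch.randn(batch, seq_len, dim)

# Move the channel dim next to the embedding's layout, add, then move it back.
hidden_states_1 = (hidden_states_1.permute(0, 2, 1) + learnable_positional_embedding).permute(0, 2, 1)
print(hidden_states_1.shape)  # torch.Size([2, 64, 192])
```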
diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
index e01aa9929dd8..a498831877c9 100644
--- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
@@ -27,6 +27,8 @@
T5EncoderModel,
T5Tokenizer,
T5TokenizerFast,
+ VitsModel,
+ VitsTokenizer,
)
from ...models import AutoencoderKL
@@ -79,6 +81,37 @@
>>> # save the best audio sample (index 0) as a .wav file
>>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
```
+ ```
+ >>> # Using AudioLDM2 for text-to-speech
+ >>> import scipy
+ >>> import torch
+ >>> from diffusers import AudioLDM2Pipeline
+
+ >>> repo_id = "anhnct/audioldm2_gigaspeech"
+ >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> # define the prompts
+ >>> prompt = "A female reporter is speaking"
+ >>> transcript = "wish you have a good day"
+
+ >>> # set the seed for generator
+ >>> generator = torch.Generator("cuda").manual_seed(0)
+
+ >>> # run the generation
+ >>> audio = pipe(
+ ... prompt,
+ ... transcription=transcript,
+ ... num_inference_steps=200,
+ ... audio_length_in_s=10.0,
+ ... num_waveforms_per_prompt=2,
+ ... generator=generator,
+ ... max_new_tokens=512,  # must set max_new_tokens equal to 512 for TTS
+ ... ).audios
+
+ >>> # save the best audio sample (index 0) as a .wav file
+ >>> scipy.io.wavfile.write("tts.wav", rate=16000, data=audio[0])
+ ```
"""
@@ -116,20 +149,23 @@ class AudioLDM2Pipeline(DiffusionPipeline):
specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The
text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to
rank generated waveforms against the text prompt by computing similarity scores.
- text_encoder_2 ([`~transformers.T5EncoderModel`]):
+ text_encoder_2 ([`~transformers.T5EncoderModel`, `~transformers.VitsModel`]):
Second frozen text-encoder. AudioLDM2 uses the encoder of
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
- [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant.
+ [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant. For text-to-speech (TTS)
+ checkpoints, AudioLDM2 instead uses the encoder of
+ [Vits](https://huggingface.co/docs/transformers/model_doc/vits#transformers.VitsModel).
projection_model ([`AudioLDM2ProjectionModel`]):
A trained model used to linearly project the hidden-states from the first and second text encoder models
and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are
- concatenated to give the input to the language model.
+ concatenated to give the input to the language model. For TTS checkpoints, a learned position embedding
+ is additionally applied to the Vits hidden-states.
language_model ([`~transformers.GPT2Model`]):
An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected
outputs from the two text encoders.
tokenizer ([`~transformers.RobertaTokenizer`]):
Tokenizer to tokenize text for the first frozen text-encoder.
- tokenizer_2 ([`~transformers.T5Tokenizer`]):
+ tokenizer_2 ([`~transformers.T5Tokenizer`, `~transformers.VitsTokenizer`]):
Tokenizer to tokenize text for the second frozen text-encoder.
feature_extractor ([`~transformers.ClapFeatureExtractor`]):
Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
@@ -146,11 +182,11 @@ def __init__(
self,
vae: AutoencoderKL,
text_encoder: ClapModel,
- text_encoder_2: T5EncoderModel,
+ text_encoder_2: Union[T5EncoderModel, VitsModel],
projection_model: AudioLDM2ProjectionModel,
language_model: GPT2Model,
tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
- tokenizer_2: Union[T5Tokenizer, T5TokenizerFast],
+ tokenizer_2: Union[T5Tokenizer, T5TokenizerFast, VitsTokenizer],
feature_extractor: ClapFeatureExtractor,
unet: AudioLDM2UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
@@ -273,6 +309,7 @@ def encode_prompt(
device,
num_waveforms_per_prompt,
do_classifier_free_guidance,
+ transcription=None,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -288,6 +325,8 @@ def encode_prompt(
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
+ transcription (`str` or `List[str]`):
+ Transcript of the text to generate speech for (used for TTS checkpoints).
device (`torch.device`):
torch device
num_waveforms_per_prompt (`int`):
@@ -368,16 +407,26 @@ def encode_prompt(
# Define tokenizers and text encoders
tokenizers = [self.tokenizer, self.tokenizer_2]
- text_encoders = [self.text_encoder, self.text_encoder_2]
+ is_vits_text_encoder = isinstance(self.text_encoder_2, VitsModel)
+
+ if is_vits_text_encoder:
+ text_encoders = [self.text_encoder, self.text_encoder_2.text_encoder]
+ else:
+ text_encoders = [self.text_encoder, self.text_encoder_2]
if prompt_embeds is None:
prompt_embeds_list = []
attention_mask_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ use_prompt = isinstance(
+ tokenizer, (RobertaTokenizer, RobertaTokenizerFast, T5Tokenizer, T5TokenizerFast)
+ )
text_inputs = tokenizer(
- prompt,
- padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True,
+ prompt if use_prompt else transcription,
+ padding="max_length"
+ if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer))
+ else True,
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
@@ -407,6 +456,18 @@ def encode_prompt(
prompt_embeds = prompt_embeds[:, None, :]
# make sure that we attend to this single hidden-state
attention_mask = attention_mask.new_ones((batch_size, 1))
+ elif is_vits_text_encoder:
+ # Add end_token_id and attention mask at the end of the phoneme sequence
+ for text_input_id, text_attention_mask in zip(text_input_ids, attention_mask):
+ for idx, phoneme_id in enumerate(text_input_id):
+ if phoneme_id == 0:
+ text_input_id[idx] = 182
+ text_attention_mask[idx] = 1
+ break
+ prompt_embeds = text_encoder(
+ text_input_ids, attention_mask=attention_mask, padding_mask=attention_mask.unsqueeze(-1)
+ )
+ prompt_embeds = prompt_embeds[0]
else:
prompt_embeds = text_encoder(
text_input_ids,
@@ -485,7 +546,7 @@ def encode_prompt(
uncond_tokens,
padding="max_length",
max_length=tokenizer.model_max_length
- if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
+ if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, VitsTokenizer))
else max_length,
truncation=True,
return_tensors="pt",
@@ -503,6 +564,15 @@ def encode_prompt(
negative_prompt_embeds = negative_prompt_embeds[:, None, :]
# make sure that we attend to this single hidden-state
negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1))
+ elif is_vits_text_encoder:
+ negative_prompt_embeds = torch.zeros(
+ batch_size,
+ tokenizer.model_max_length,
+ text_encoder.config.hidden_size,
+ ).to(dtype=self.text_encoder_2.dtype, device=device)
+ negative_attention_mask = torch.zeros(batch_size, tokenizer.model_max_length).to(
+ dtype=self.text_encoder_2.dtype, device=device
+ )
else:
negative_prompt_embeds = text_encoder(
uncond_input_ids,
@@ -623,6 +693,7 @@ def check_inputs(
audio_length_in_s,
vocoder_upsample_factor,
callback_steps,
+ transcription=None,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
@@ -690,6 +761,14 @@ def check_inputs(
f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
)
+ if transcription is None:
+ if self.text_encoder_2.config.model_type == "vits":
+ raise ValueError("Cannot forward without transcription. Please make sure to" " have transcription")
+ elif transcription is not None and (
+ not isinstance(transcription, str) and not isinstance(transcription, list)
+ ):
+ raise ValueError(f"`transcription` has to be of type `str` or `list` but is {type(transcription)}")
+
if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None:
if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape:
raise ValueError(
@@ -711,8 +790,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, devic
shape = (
batch_size,
num_channels_latents,
- height // self.vae_scale_factor,
- self.vocoder.config.model_in_dim // self.vae_scale_factor,
+ int(height) // self.vae_scale_factor,
+ int(self.vocoder.config.model_in_dim) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -734,6 +813,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, devic
def __call__(
self,
prompt: Union[str, List[str]] = None,
+ transcription: Union[str, List[str]] = None,
audio_length_in_s: Optional[float] = None,
num_inference_steps: int = 200,
guidance_scale: float = 3.5,
@@ -761,6 +841,8 @@ def __call__(
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
+ transcription (`str` or `List[str]`, *optional*):
+ The transcript for text-to-speech generation.
audio_length_in_s (`int`, *optional*, defaults to 10.24):
The length of the generated audio sample in seconds.
num_inference_steps (`int`, *optional*, defaults to 200):
@@ -857,6 +939,7 @@ def __call__(
audio_length_in_s,
vocoder_upsample_factor,
callback_steps,
+ transcription,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
@@ -886,6 +969,7 @@ def __call__(
device,
num_waveforms_per_prompt,
do_classifier_free_guidance,
+ transcription,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
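To summarize the `encode_prompt` changes above: the RoBERTa/T5 tokenizers keep receiving `prompt`, while a Vits tokenizer (the TTS case) receives `transcription`, and the Vits model is unwrapped to its `.text_encoder`. A trimmed sketch of just the routing decision, using the same tokenizer classes (no weights are loaded here):

```py
from transformers import (
    RobertaTokenizer,
    RobertaTokenizerFast,
    T5Tokenizer,
    T5TokenizerFast,
    VitsTokenizer,
)


def pick_text_for_tokenizer(tokenizer, prompt, transcription):
    """Return the string the given tokenizer should encode, mirroring the diff above."""
    use_prompt = isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast, T5Tokenizer, T5TokenizerFast))
    # Vits tokenizers (TTS checkpoints) encode the transcript; the others encode the style prompt.
    return prompt if use_prompt else transcription
```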
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index 3c69fb06332c..f099d54e57cd 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -107,8 +107,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -807,7 +807,12 @@ def prepare_image(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -922,9 +927,9 @@ def __call__(
accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
`init`, images must be passed as a list such that each element of the list can be correctly batched for
- input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single ControlNet,
- each will be paired with each prompt in the `prompt` list. This also applies to multiple ControlNets,
- where a list of image lists can be passed to batch for each prompt and each ControlNet.
+ input to a single ControlNet. When `prompt` is a list, and if a list of images is passed for a single
+ ControlNet, each will be paired with each prompt in the `prompt` list. This also applies to multiple
+ ControlNets, where a list of image lists can be passed to batch for each prompt and each ControlNet.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -962,10 +967,10 @@ def __call__(
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
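As the reflowed `image` docstring above spells out, a single ControlNet can take a list of conditioning images paired index-by-index with a list of prompts. A hedged sketch of that batched call; the checkpoint ids are illustrative, and the documentation-asset image stands in for real canny maps, reused twice for brevity:

```py
import torch

from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

cond_a = load_image(
    "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
)
cond_b = cond_a  # stand-in second conditioning image

# One conditioning image per prompt; each image is paired with the prompt at the same index.
images = pipe(
    prompt=["a blue neon sign", "a chalk drawing on a wall"],
    image=[cond_a, cond_b],
    num_inference_steps=20,
).images
```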
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
index 403fe6a9e797..a5a0aaed0f2e 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -978,10 +978,10 @@ def __call__(
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index ddc0983f304d..47dbb26eb3e4 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -972,7 +972,12 @@ def prepare_latents(
return_noise=False,
return_image_latents=False,
):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -1167,11 +1172,12 @@ def __call__(
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
padding_mask_crop (`int`, *optional*, defaults to `None`):
- The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
- `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
- contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
- the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
- and contain information irrelevant for inpainting, such as background.
+ The size of the margin in the crop applied to the image and masking. If `None`, no crop is applied to
+ the image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+ with the same aspect ratio as the image that contains all of the masked area, and then expand that
+ region by `padding_mask_crop`. The image and mask_image are then cropped to the expanded region before
+ being resized to the original image size for inpainting. This is useful when the masked area is small
+ while the image is large and contains information irrelevant to inpainting, such as the background.
strength (`float`, *optional*, defaults to 1.0):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
@@ -1207,10 +1213,10 @@ def __call__(
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
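To make the reworded `padding_mask_crop` description above concrete, here is a rough standalone sketch of the bounding-box-plus-margin idea. It is an illustration only: the pipeline's real implementation goes through its image processor and additionally matches the crop to the image's aspect ratio before resizing back.

```py
import numpy as np
from PIL import Image


def crop_region_with_margin(mask: Image.Image, margin: int):
    # Bounding box of the masked (non-zero) area, expanded by `margin` pixels
    # and clamped to the image borders.
    ys, xs = np.nonzero(np.array(mask.convert("L")))
    x0, x1 = max(xs.min() - margin, 0), min(xs.max() + margin, mask.width)
    y0, y1 = max(ys.min() - margin, 0), min(ys.max() + margin, mask.height)
    return int(x0), int(y0), int(x1), int(y1)


mask = Image.new("L", (512, 512))
mask.paste(255, (200, 220, 260, 300))  # hypothetical masked rectangle
print(crop_region_with_margin(mask, margin=32))  # (168, 188, 291, 331)
```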
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index 3eb8f31b6a26..18c4370b8025 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -880,7 +880,12 @@ def prepare_latents(
return_noise=False,
return_image_latents=False,
):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -1194,11 +1199,12 @@ def __call__(
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
padding_mask_crop (`int`, *optional*, defaults to `None`):
- The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
- `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
- contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
- the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
- and contain information irrelevant for inpainting, such as background.
+ The size of the margin in the crop applied to the image and masking. If `None`, no crop is applied to
+ the image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+ with the same aspect ratio as the image that contains all of the masked area, and then expand that
+ region by `padding_mask_crop`. The image and mask_image are then cropped to the expanded region before
+ being resized to the original image size for inpainting. This is useful when the masked area is small
+ while the image is large and contains information irrelevant to inpainting, such as the background.
strength (`float`, *optional*, defaults to 0.9999):
Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
@@ -1247,10 +1253,10 @@ def __call__(
argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index d6591aa26f2a..2307b856ad63 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -813,7 +813,12 @@ def prepare_image(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -1039,10 +1044,10 @@ def __call__(
argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index 6c00e2f3fc4b..d32e7d81649d 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -898,6 +898,12 @@ def prepare_latents(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
+ latents_mean = latents_std = None
+ if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+ latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
+
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
@@ -935,7 +941,12 @@ def prepare_latents(
self.vae.to(dtype)
init_latents = init_latents.to(dtype)
- init_latents = self.vae.config.scaling_factor * init_latents
+ if latents_mean is not None and latents_std is not None:
+ latents_mean = latents_mean.to(device=self.device, dtype=dtype)
+ latents_std = latents_std.to(device=self.device, dtype=dtype)
+ init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
+ else:
+ init_latents = self.vae.config.scaling_factor * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
@@ -1178,10 +1189,10 @@ def __call__(
input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
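The img2img change above switches, when the VAE config provides `latents_mean`/`latents_std`, from plain `scaling_factor * latents` to a standardize-then-scale form. A small sketch of the two branches, with hypothetical statistics:

```py
import torch

scaling_factor = 0.13025                       # e.g. an SDXL-style VAE scaling factor
latents_mean = torch.zeros(1, 4, 1, 1) + 0.1   # hypothetical per-channel statistics
latents_std = torch.ones(1, 4, 1, 1) * 0.9

init_latents = torch.randn(1, 4, 64, 64)

# With statistics available: standardize first, then apply the scaling factor.
normalized = (init_latents - latents_mean) * scaling_factor / latents_std

# Fallback when the VAE config carries no statistics.
fallback = scaling_factor * init_latents
```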
diff --git a/src/diffusers/pipelines/controlnet_xs/__init__.py b/src/diffusers/pipelines/controlnet_xs/__init__.py
new file mode 100644
index 000000000000..978278b184f9
--- /dev/null
+++ b/src/diffusers/pipelines/controlnet_xs/__init__.py
@@ -0,0 +1,68 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_flax_available,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_controlnet_xs"] = ["StableDiffusionControlNetXSPipeline"]
+ _import_structure["pipeline_controlnet_xs_sd_xl"] = ["StableDiffusionXLControlNetXSPipeline"]
+try:
+ if not (is_transformers_available() and is_flax_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_flax_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
+else:
+ pass # _import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
+
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_controlnet_xs import StableDiffusionControlNetXSPipeline
+ from .pipeline_controlnet_xs_sd_xl import StableDiffusionXLControlNetXSPipeline
+
+ try:
+ if not (is_transformers_available() and is_flax_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
+ else:
+ pass # from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
+
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
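The new `__init__.py` above follows the repository's `_LazyModule` pattern: names are only listed in `_import_structure` and resolved on first attribute access, so the torch/transformers-backed pipeline modules are not imported until they are actually used. A usage sketch against the subpackage path defined by this file, assuming both backends are installed:

```py
# The attribute lookup performed by this import is what triggers the lazy load of
# pipeline_controlnet_xs; importing the subpackage alone stays cheap.
from diffusers.pipelines.controlnet_xs import StableDiffusionControlNetXSPipeline

print(StableDiffusionControlNetXSPipeline.__name__)
```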
diff --git a/examples/research_projects/controlnetxs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
similarity index 81%
rename from examples/research_projects/controlnetxs/pipeline_controlnet_xs.py
rename to src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
index 88a586e9271d..622bac8c5f97 100644
--- a/examples/research_projects/controlnetxs/pipeline_controlnet_xs.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
@@ -19,30 +19,75 @@
import PIL.Image
import torch
import torch.nn.functional as F
-from controlnetxs import ControlNetXSModel
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
-from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
-from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.models.lora import adjust_lora_scale_text_encoder
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
-from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import (
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
+ replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
-from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install opencv-python transformers accelerate
+ >>> from diffusers import StableDiffusionControlNetXSPipeline, ControlNetXSAdapter
+ >>> from diffusers.utils import load_image
+ >>> import numpy as np
+ >>> import torch
+
+ >>> import cv2
+ >>> from PIL import Image
+
+ >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
+ >>> negative_prompt = "low quality, bad quality, sketches"
+
+ >>> # download an image
+ >>> image = load_image(
+ ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
+ ... )
+
+ >>> # initialize the models and pipeline
+ >>> controlnet_conditioning_scale = 0.5
+
+ >>> controlnet = ControlNetXSAdapter.from_pretrained(
+ ... "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
+ ... )
+ >>> pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # get canny image
+ >>> image = np.array(image)
+ >>> image = cv2.Canny(image, 100, 200)
+ >>> image = image[:, :, None]
+ >>> image = np.concatenate([image, image, image], axis=2)
+ >>> canny_image = Image.fromarray(image)
+ >>> # generate image
+ >>> image = pipe(
+ ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
+ ... ).images[0]
+ ```
+"""
+
+
class StableDiffusionControlNetXSPipeline(
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
@@ -56,7 +101,7 @@ class StableDiffusionControlNetXSPipeline(
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
@@ -66,9 +111,9 @@ class StableDiffusionControlNetXSPipeline(
tokenizer ([`~transformers.CLIPTokenizer`]):
A `CLIPTokenizer` to tokenize text.
unet ([`UNet2DConditionModel`]):
- A `UNet2DConditionModel` to denoise the encoded image latents.
- controlnet ([`ControlNetXSModel`]):
- Provides additional conditioning to the `unet` during the denoising process.
+ A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents.
+ controlnet ([`ControlNetXSAdapter`]):
+ A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
@@ -80,17 +125,18 @@ class StableDiffusionControlNetXSPipeline(
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
- model_cpu_offload_seq = "text_encoder->unet->vae>controlnet"
+ model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
- unet: UNet2DConditionModel,
- controlnet: ControlNetXSModel,
+ unet: Union[UNet2DConditionModel, UNetControlNetXSModel],
+ controlnet: ControlNetXSAdapter,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
@@ -98,6 +144,9 @@ def __init__(
):
super().__init__()
+ if isinstance(unet, UNet2DConditionModel):
+ unet = UNetControlNetXSModel.from_unet(unet, controlnet)
+
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
@@ -114,14 +163,6 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible(
- vae
- )
- if not vae_compatible:
- raise ValueError(
- f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`."
- )
-
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -403,20 +444,19 @@ def check_inputs(
self,
prompt,
image,
- callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
+ callback_on_step_end_tensor_inputs=None,
):
- if (callback_steps is None) or (
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
- f" {type(callback_steps)}."
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
@@ -445,25 +485,16 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)
- # Check `image`
+ # Check `image` and `controlnet_conditioning_scale`
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
- self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ self.unet, torch._dynamo.eval_frame.OptimizedModule
)
if (
- isinstance(self.controlnet, ControlNetXSModel)
+ isinstance(self.unet, UNetControlNetXSModel)
or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
+ and isinstance(self.unet._orig_mod, UNetControlNetXSModel)
):
self.check_image(image, prompt, prompt_embeds)
- else:
- assert False
-
- # Check `controlnet_conditioning_scale`
- if (
- isinstance(self.controlnet, ControlNetXSModel)
- or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
- ):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
else:
@@ -547,7 +578,12 @@ def prepare_image(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -563,7 +599,33 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
latents = latents * self.scheduler.init_noise_sigma
return latents
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
+ def clip_skip(self):
+ return self._clip_skip
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
+ def num_timesteps(self):
+ return self._num_timesteps
+
@torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
@@ -581,13 +643,13 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
- callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
control_guidance_start: float = 0.0,
control_guidance_end: float = 1.0,
clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
r"""
The call function to the pipeline for generation.
@@ -595,7 +657,7 @@ def __call__(
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
- image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
@@ -639,12 +701,6 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
- callback (`Callable`, *optional*):
- A function that calls every `callback_steps` steps during inference. The function is called with the
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
- callback_steps (`int`, *optional*, defaults to 1):
- The frequency at which the `callback` function is called. If not specified, the callback is called at
- every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -659,7 +715,15 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
-
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. It is called with the
+ following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as the `callback_kwargs` argument. You can only include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
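The new `callback_on_step_end`/`callback_on_step_end_tensor_inputs` pair replaces the removed `callback`/`callback_steps` arguments. A minimal sketch of a hook that just inspects the latents at each step; `pipe`, `prompt`, and `canny_image` are assumed to be set up as in the example docstring earlier in this file, and returning the kwargs dict is the conservative choice since pipelines may read modified tensors back from it:

```py
def log_latents(pipeline, step, timestep, callback_kwargs):
    # "latents" is available because it is listed in callback_on_step_end_tensor_inputs below.
    latents = callback_kwargs["latents"]
    print(f"step {step:03d} | t={int(timestep)} | latents std={latents.std().item():.4f}")
    return callback_kwargs


image = pipe(
    prompt,
    image=canny_image,
    callback_on_step_end=log_latents,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```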
@@ -669,21 +733,27 @@ def __call__(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
- controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
image,
- callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
+ callback_on_step_end_tensor_inputs,
)
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -713,6 +783,7 @@ def __call__(
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
)
+
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
@@ -720,27 +791,24 @@ def __call__(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare image
- if isinstance(controlnet, ControlNetXSModel):
- image = self.prepare_image(
- image=image,
- width=width,
- height=height,
- batch_size=batch_size * num_images_per_prompt,
- num_images_per_prompt=num_images_per_prompt,
- device=device,
- dtype=controlnet.dtype,
- do_classifier_free_guidance=do_classifier_free_guidance,
- )
- height, width = image.shape[-2:]
- else:
- assert False
+ image = self.prepare_image(
+ image=image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=unet.dtype,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+ height, width = image.shape[-2:]
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
- num_channels_latents = self.unet.config.in_channels
+ num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
@@ -757,42 +825,33 @@ def __call__(
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
- is_unet_compiled = is_compiled_module(self.unet)
- is_controlnet_compiled = is_compiled_module(self.controlnet)
+ self._num_timesteps = len(timesteps)
+ is_controlnet_compiled = is_compiled_module(self.unet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
- if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+ if is_controlnet_compiled and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
- dont_control = (
- i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end
+ apply_control = (
+ i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end
)
- if dont_control:
- noise_pred = self.unet(
- sample=latent_model_input,
- timestep=t,
- encoder_hidden_states=prompt_embeds,
- cross_attention_kwargs=cross_attention_kwargs,
- return_dict=True,
- ).sample
- else:
- noise_pred = self.controlnet(
- base_model=self.unet,
- sample=latent_model_input,
- timestep=t,
- encoder_hidden_states=prompt_embeds,
- controlnet_cond=image,
- conditioning_scale=controlnet_conditioning_scale,
- cross_attention_kwargs=cross_attention_kwargs,
- return_dict=True,
- ).sample
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=prompt_embeds,
+ controlnet_cond=image,
+ conditioning_scale=controlnet_conditioning_scale,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=True,
+ apply_control=apply_control,
+ ).sample
# perform guidance
if do_classifier_free_guidance:
@@ -801,12 +860,18 @@ def __call__(
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
- # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
- if callback is not None and i % callback_steps == 0:
- step_idx = i // getattr(self.scheduler, "order", 1)
- callback(step_idx, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
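The loop above no longer switches between `self.unet` and `self.controlnet`: the fused `UNetControlNetXSModel` is called at every step, and control conditioning is toggled purely through the `apply_control` flag. A minimal sketch of how the guidance window maps onto per-step flags, using made-up values for the step count and window bounds:

```py
# Illustration only; `num_steps` and the window bounds are example values, not from the patch.
num_steps = 10
control_guidance_start, control_guidance_end = 0.2, 0.8

apply_control = [
    (i / num_steps >= control_guidance_start) and ((i + 1) / num_steps <= control_guidance_end)
    for i in range(num_steps)
]
print(apply_control)  # True only for steps 2..7; control is skipped at the start and end of sampling
```

Keeping a single module call with a boolean switch also keeps the call signature stable across steps, which presumably plays more nicely with `torch.compile`/cudagraphs than alternating between two modules.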
diff --git a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
similarity index 84%
rename from examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
rename to src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
index d0186573fa9c..3ab535a054bd 100644
--- a/examples/research_projects/controlnetxs/pipeline_controlnet_xs_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
@@ -19,41 +19,93 @@
import PIL.Image
import torch
import torch.nn.functional as F
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+)
+
+from diffusers.utils.import_utils import is_invisible_watermark_available
-from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
-from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
-from diffusers.models import AutoencoderKL, ControlNetXSModel, UNet2DConditionModel
-from diffusers.models.attention_processor import (
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
+from ...models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
-from diffusers.models.lora import adjust_lora_scale_text_encoder
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
-from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
-from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import (
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
USE_PEFT_BACKEND,
logging,
+ replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
-from diffusers.utils.import_utils import is_invisible_watermark_available
-from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
if is_invisible_watermark_available():
- from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+ from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install opencv-python transformers accelerate
+ >>> from diffusers import StableDiffusionXLControlNetXSPipeline, ControlNetXSAdapter, AutoencoderKL
+ >>> from diffusers.utils import load_image
+ >>> import numpy as np
+ >>> import torch
+
+ >>> import cv2
+ >>> from PIL import Image
+
+ >>> prompt = "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting"
+ >>> negative_prompt = "low quality, bad quality, sketches"
+
+ >>> # download an image
+ >>> image = load_image(
+ ... "https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
+ ... )
+
+ >>> # initialize the models and pipeline
+ >>> controlnet_conditioning_scale = 0.5
+ >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
+ >>> controlnet = ControlNetXSAdapter.from_pretrained(
+ ... "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16
+ ... )
+ >>> pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # get canny image
+ >>> image = np.array(image)
+ >>> image = cv2.Canny(image, 100, 200)
+ >>> image = image[:, :, None]
+ >>> image = np.concatenate([image, image, image], axis=2)
+ >>> canny_image = Image.fromarray(image)
+
+ >>> # generate image
+ >>> image = pipe(
+ ... prompt, controlnet_conditioning_scale=controlnet_conditioning_scale, image=canny_image
+ ... ).images[0]
+ ```
+"""
+
+
class StableDiffusionXLControlNetXSPipeline(
DiffusionPipeline,
- StableDiffusionMixin,
TextualInversionLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
FromSingleFileMixin,
@@ -66,9 +118,8 @@ class StableDiffusionXLControlNetXSPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
@@ -83,9 +134,9 @@ class StableDiffusionXLControlNetXSPipeline(
tokenizer_2 ([`~transformers.CLIPTokenizer`]):
A `CLIPTokenizer` to tokenize text.
unet ([`UNet2DConditionModel`]):
- A `UNet2DConditionModel` to denoise the encoded image latents.
- controlnet ([`ControlNetXSModel`]:
- Provides additional conditioning to the `unet` during the denoising process.
+ A [`UNet2DConditionModel`] used to create a UNetControlNetXSModel to denoise the encoded image latents.
+ controlnet ([`ControlNetXSAdapter`]):
+ A [`ControlNetXSAdapter`] to be used in combination with `unet` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
@@ -98,9 +149,15 @@ class StableDiffusionXLControlNetXSPipeline(
watermarker is used.
"""
- # leave controlnet out on purpose because it iterates with unet
- model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae->controlnet"
- _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "feature_extractor",
+ ]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -109,21 +166,17 @@ def __init__(
text_encoder_2: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
- unet: UNet2DConditionModel,
- controlnet: ControlNetXSModel,
+ unet: Union[UNet2DConditionModel, UNetControlNetXSModel],
+ controlnet: ControlNetXSAdapter,
scheduler: KarrasDiffusionSchedulers,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
+ feature_extractor: CLIPImageProcessor = None,
):
super().__init__()
- vae_compatible, cnxs_condition_downsample_factor, vae_downsample_factor = controlnet._check_if_vae_compatible(
- vae
- )
- if not vae_compatible:
- raise ValueError(
- f"The downsampling factors of the VAE ({vae_downsample_factor}) and the conditioning part of ControlNetXS model {cnxs_condition_downsample_factor} need to be equal. Consider building the ControlNetXS model with different `conditioning_block_sizes`."
- )
+ if isinstance(unet, UNet2DConditionModel):
+ unet = UNetControlNetXSModel.from_unet(unet, controlnet)
self.register_modules(
vae=vae,
@@ -134,6 +187,7 @@ def __init__(
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
+ feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
@@ -407,7 +461,6 @@ def check_inputs(
prompt,
prompt_2,
image,
- callback_steps,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
@@ -417,13 +470,13 @@ def check_inputs(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
+ callback_on_step_end_tensor_inputs=None,
):
- if (callback_steps is None) or (
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
- f" {type(callback_steps)}."
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
@@ -474,25 +527,16 @@ def check_inputs(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
- # Check `image`
+ # Check `image` and `controlnet_conditioning_scale`
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
- self.controlnet, torch._dynamo.eval_frame.OptimizedModule
+ self.unet, torch._dynamo.eval_frame.OptimizedModule
)
if (
- isinstance(self.controlnet, ControlNetXSModel)
+ isinstance(self.unet, UNetControlNetXSModel)
or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
+ and isinstance(self.unet._orig_mod, UNetControlNetXSModel)
):
self.check_image(image, prompt, prompt_embeds)
- else:
- assert False
-
- # Check `controlnet_conditioning_scale`
- if (
- isinstance(self.controlnet, ControlNetXSModel)
- or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetXSModel)
- ):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
else:
@@ -577,7 +621,12 @@ def prepare_image(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -593,7 +642,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
latents = latents * self.scheduler.init_noise_sigma
return latents
- # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
def _get_add_time_ids(
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
):
@@ -602,7 +650,7 @@ def _get_add_time_ids(
passed_add_embed_dim = (
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
)
- expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+ expected_add_embed_dim = self.unet.base_add_embedding.linear_1.in_features
if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
@@ -632,7 +680,33 @@ def upcast_vae(self):
self.vae.decoder.conv_in.to(dtype)
self.vae.decoder.mid_block.to(dtype)
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
+ def clip_skip(self):
+ return self._clip_skip
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
+ def num_timesteps(self):
+ return self._num_timesteps
+
@torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
@@ -654,8 +728,6 @@ def __call__(
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
- callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
control_guidance_start: float = 0.0,
@@ -667,6 +739,8 @@ def __call__(
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
r"""
The call function to the pipeline for generation.
@@ -677,7 +751,7 @@ def __call__(
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders.
- image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
@@ -735,12 +809,6 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
- callback (`Callable`, *optional*):
- A function that calls every `callback_steps` steps during inference. The function is called with the
- following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
- callback_steps (`int`, *optional*, defaults to 1):
- The frequency at which the `callback` function is called. If not specified, the callback is called at
- every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -783,6 +851,15 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+ the `._callback_tensor_inputs` attribute of your pipeline class.
Examples:
@@ -791,14 +868,14 @@ def __call__(
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] is
returned, otherwise a `tuple` is returned containing the output images.
"""
- controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ unet = self.unet._orig_mod if is_compiled_module(self.unet) else self.unet
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
image,
- callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
@@ -808,8 +885,14 @@ def __call__(
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
+ callback_on_step_end_tensor_inputs,
)
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -850,7 +933,7 @@ def __call__(
)
# 4. Prepare image
- if isinstance(controlnet, ControlNetXSModel):
+ if isinstance(unet, UNetControlNetXSModel):
image = self.prepare_image(
image=image,
width=width,
@@ -858,7 +941,7 @@ def __call__(
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
- dtype=controlnet.dtype,
+ dtype=unet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
height, width = image.shape[-2:]
@@ -870,7 +953,7 @@ def __call__(
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
- num_channels_latents = self.unet.config.in_channels
+ num_channels_latents = self.unet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
@@ -928,14 +1011,14 @@ def __call__(
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
- is_unet_compiled = is_compiled_module(self.unet)
- is_controlnet_compiled = is_compiled_module(self.controlnet)
+ self._num_timesteps = len(timesteps)
+ is_controlnet_compiled = is_compiled_module(self.unet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
- if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+ if is_controlnet_compiled and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -944,30 +1027,20 @@ def __call__(
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# predict the noise residual
- dont_control = (
- i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end
+ apply_control = (
+ i / len(timesteps) >= control_guidance_start and (i + 1) / len(timesteps) <= control_guidance_end
)
- if dont_control:
- noise_pred = self.unet(
- sample=latent_model_input,
- timestep=t,
- encoder_hidden_states=prompt_embeds,
- cross_attention_kwargs=cross_attention_kwargs,
- added_cond_kwargs=added_cond_kwargs,
- return_dict=True,
- ).sample
- else:
- noise_pred = self.controlnet(
- base_model=self.unet,
- sample=latent_model_input,
- timestep=t,
- encoder_hidden_states=prompt_embeds,
- controlnet_cond=image,
- conditioning_scale=controlnet_conditioning_scale,
- cross_attention_kwargs=cross_attention_kwargs,
- added_cond_kwargs=added_cond_kwargs,
- return_dict=True,
- ).sample
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=prompt_embeds,
+ controlnet_cond=image,
+ conditioning_scale=controlnet_conditioning_scale,
+ cross_attention_kwargs=cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=True,
+ apply_control=apply_control,
+ ).sample
# perform guidance
if do_classifier_free_guidance:
@@ -977,12 +1050,24 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
- if callback is not None and i % callback_steps == 0:
- step_idx = i // getattr(self.scheduler, "order", 1)
- callback(step_idx, t, latents)
+
+ # manually for max memory savings
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
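With `callback`/`callback_steps` gone, per-step hooks now go through `callback_on_step_end`. A rough sketch of the new interface, reusing the checkpoints from the example docstring above and assuming `canny_image` has been prepared as shown there; the callback body itself is purely illustrative:

```py
import torch
from diffusers import ControlNetXSAdapter, StableDiffusionXLControlNetXSPipeline

# Model IDs reused from the example docstring above.
controlnet = ControlNetXSAdapter.from_pretrained(
    "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()


def on_step_end(pipeline, step, timestep, callback_kwargs):
    # Only tensors named in `callback_on_step_end_tensor_inputs` show up here.
    latents = callback_kwargs["latents"]
    print(f"step {step:3d} | t={int(timestep):4d} | latents std={latents.std().item():.3f}")
    # Whatever is returned under these keys is written back into the denoising loop.
    return callback_kwargs


image = pipe(
    "aerial view, a futuristic research complex in a bright foggy jungle, hard lighting",
    image=canny_image,  # conditioning image prepared as in the docstring example
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```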
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
index 7adf9e9c47cf..5bd396b20fad 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
@@ -12,7 +12,6 @@
from ...schedulers import DDPMScheduler
from ...utils import (
BACKENDS_MAPPING,
- is_accelerate_available,
is_bs4_available,
is_ftfy_available,
logging,
@@ -115,6 +114,7 @@ class IFPipeline(DiffusionPipeline, LoraLoaderMixin):
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
model_cpu_offload_seq = "text_encoder->unet"
+ _exclude_from_cpu_offload = ["watermarker"]
def __init__(
self,
@@ -156,20 +156,6 @@ def __init__(
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
- def remove_all_hooks(self):
- if is_accelerate_available():
- from accelerate.hooks import remove_hook_from_module
- else:
- raise ImportError("Please install accelerate via `pip install accelerate`")
-
- for model in [self.text_encoder, self.unet, self.safety_checker]:
- if model is not None:
- remove_hook_from_module(model, recurse=True)
-
- self.unet_offload_hook = None
- self.text_encoder_offload_hook = None
- self.final_offload_hook = None
-
@torch.no_grad()
def encode_prompt(
self,
@@ -335,9 +321,6 @@ def run_safety_checker(self, image, device, dtype):
nsfw_detected = None
watermark_detected = None
- if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
- self.unet_offload_hook.offload()
-
return image, nsfw_detected, watermark_detected
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
@@ -691,6 +674,9 @@ def __call__(
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(0)
+
# 5. Prepare intermediate images
intermediate_images = self.prepare_intermediate_images(
batch_size * num_images_per_prompt,
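The IF pipelines now rely on the shared `DiffusionPipeline` offloading machinery instead of their own `remove_all_hooks`/offload-hook bookkeeping, and the watermarker is kept out of the offload sequence via `_exclude_from_cpu_offload`. A rough usage sketch under those assumptions; the checkpoint name is illustrative:

```py
import torch
from diffusers import IFPipeline

# Checkpoint name is an assumption for illustration purposes.
pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)

# Offloading follows `model_cpu_offload_seq` ("text_encoder->unet"); the watermarker is listed in
# `_exclude_from_cpu_offload`, so it never gets wrapped with offload hooks.
pipe.enable_model_cpu_offload()

image = pipe("a photo of a kangaroo wearing an orange hoodie").images[0]

# Cleanup now uses the base-class implementation rather than the deleted per-pipeline override.
pipe.remove_all_hooks()
```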
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
index 99633ee21535..50e2cda25a48 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
@@ -15,7 +15,6 @@
from ...utils import (
BACKENDS_MAPPING,
PIL_INTERPOLATION,
- is_accelerate_available,
is_bs4_available,
is_ftfy_available,
logging,
@@ -139,6 +138,7 @@ class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
model_cpu_offload_seq = "text_encoder->unet"
+ _exclude_from_cpu_offload = ["watermarker"]
def __init__(
self,
@@ -180,21 +180,6 @@ def __init__(
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
- # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
- def remove_all_hooks(self):
- if is_accelerate_available():
- from accelerate.hooks import remove_hook_from_module
- else:
- raise ImportError("Please install accelerate via `pip install accelerate`")
-
- for model in [self.text_encoder, self.unet, self.safety_checker]:
- if model is not None:
- remove_hook_from_module(model, recurse=True)
-
- self.unet_offload_hook = None
- self.text_encoder_offload_hook = None
- self.final_offload_hook = None
-
@torch.no_grad()
def encode_prompt(
self,
@@ -361,9 +346,6 @@ def run_safety_checker(self, image, device, dtype):
nsfw_detected = None
watermark_detected = None
- if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
- self.unet_offload_hook.offload()
-
return image, nsfw_detected, watermark_detected
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
@@ -633,12 +615,15 @@ def numpy_to_pt(images):
return image
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
- timesteps = self.scheduler.timesteps[t_start:]
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
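The rewritten `get_timesteps` matches the img2img implementation: the slice start is scaled by `scheduler.order` (higher-order schedulers repeat timesteps, so `scheduler.timesteps` is longer than `num_inference_steps`) and, when the scheduler supports it, `set_begin_index` tells it where denoising actually starts. A small worked example with made-up numbers:

```py
# Worked example; the numbers are illustrative, not from the patch.
num_inference_steps, strength, scheduler_order = 50, 0.6, 2  # e.g. a second-order scheduler

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 30
t_start = max(num_inference_steps - init_timestep, 0)                          # 20

begin_index = t_start * scheduler_order                                        # 40
# timesteps = scheduler.timesteps[begin_index:]       # skip the first 40 entries, not just 20
# if hasattr(scheduler, "set_begin_index"):
#     scheduler.set_begin_index(begin_index)          # so step-index bookkeeping starts there too
```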
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
index 19c4f1d3909a..89eb97a087f8 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
@@ -16,7 +16,6 @@
from ...utils import (
BACKENDS_MAPPING,
PIL_INTERPOLATION,
- is_accelerate_available,
is_bs4_available,
is_ftfy_available,
logging,
@@ -143,6 +142,7 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor"]
model_cpu_offload_seq = "text_encoder->unet"
+ _exclude_from_cpu_offload = ["watermarker"]
def __init__(
self,
@@ -191,21 +191,6 @@ def __init__(
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
- # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
- def remove_all_hooks(self):
- if is_accelerate_available():
- from accelerate.hooks import remove_hook_from_module
- else:
- raise ImportError("Please install accelerate via `pip install accelerate`")
-
- for model in [self.text_encoder, self.unet, self.safety_checker]:
- if model is not None:
- remove_hook_from_module(model, recurse=True)
-
- self.unet_offload_hook = None
- self.text_encoder_offload_hook = None
- self.final_offload_hook = None
-
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
@@ -513,9 +498,6 @@ def run_safety_checker(self, image, device, dtype):
nsfw_detected = None
watermark_detected = None
- if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
- self.unet_offload_hook.offload()
-
return image, nsfw_detected, watermark_detected
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
@@ -714,13 +696,15 @@ def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device
return image
- # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
- timesteps = self.scheduler.timesteps[t_start:]
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@@ -1010,8 +994,6 @@ def __call__(
nsfw_detected = None
watermark_detected = None
- if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
- self.unet_offload_hook.offload()
else:
# 10. Post-processing
image = (image / 2 + 0.5).clamp(0, 1)
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
index 66a185b24f21..aabe1107fc9e 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
@@ -15,7 +15,6 @@
from ...utils import (
BACKENDS_MAPPING,
PIL_INTERPOLATION,
- is_accelerate_available,
is_bs4_available,
is_ftfy_available,
logging,
@@ -142,6 +141,7 @@ class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin):
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
model_cpu_offload_seq = "text_encoder->unet"
+ _exclude_from_cpu_offload = ["watermarker"]
def __init__(
self,
@@ -183,21 +183,6 @@ def __init__(
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
- # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
- def remove_all_hooks(self):
- if is_accelerate_available():
- from accelerate.hooks import remove_hook_from_module
- else:
- raise ImportError("Please install accelerate via `pip install accelerate`")
-
- for model in [self.text_encoder, self.unet, self.safety_checker]:
- if model is not None:
- remove_hook_from_module(model, recurse=True)
-
- self.unet_offload_hook = None
- self.text_encoder_offload_hook = None
- self.final_offload_hook = None
-
@torch.no_grad()
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.encode_prompt
def encode_prompt(
@@ -365,9 +350,6 @@ def run_safety_checker(self, image, device, dtype):
nsfw_detected = None
watermark_detected = None
- if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
- self.unet_offload_hook.offload()
-
return image, nsfw_detected, watermark_detected
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
@@ -723,13 +705,15 @@ def preprocess_mask_image(self, mask_image) -> torch.Tensor:
return mask_image
- # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
- timesteps = self.scheduler.timesteps[t_start:]
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
index 5c01dfdc299c..1798e0dec7eb 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
@@ -16,7 +16,6 @@
from ...utils import (
BACKENDS_MAPPING,
PIL_INTERPOLATION,
- is_accelerate_available,
is_bs4_available,
is_ftfy_available,
logging,
@@ -145,6 +144,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
model_cpu_offload_seq = "text_encoder->unet"
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
+ _exclude_from_cpu_offload = ["watermarker"]
def __init__(
self,
@@ -193,21 +193,6 @@ def __init__(
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
- # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
- def remove_all_hooks(self):
- if is_accelerate_available():
- from accelerate.hooks import remove_hook_from_module
- else:
- raise ImportError("Please install accelerate via `pip install accelerate`")
-
- for model in [self.text_encoder, self.unet, self.safety_checker]:
- if model is not None:
- remove_hook_from_module(model, recurse=True)
-
- self.unet_offload_hook = None
- self.text_encoder_offload_hook = None
- self.final_offload_hook = None
-
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
@@ -515,9 +500,6 @@ def run_safety_checker(self, image, device, dtype):
nsfw_detected = None
watermark_detected = None
- if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
- self.unet_offload_hook.offload()
-
return image, nsfw_detected, watermark_detected
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
@@ -800,13 +782,15 @@ def preprocess_mask_image(self, mask_image) -> torch.Tensor:
return mask_image
- # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
- timesteps = self.scheduler.timesteps[t_start:]
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
index a293343ebeea..36ed34cba98e 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
@@ -15,7 +15,6 @@
from ...schedulers import DDPMScheduler
from ...utils import (
BACKENDS_MAPPING,
- is_accelerate_available,
is_bs4_available,
is_ftfy_available,
logging,
@@ -101,6 +100,7 @@ class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
_optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
model_cpu_offload_seq = "text_encoder->unet"
+ _exclude_from_cpu_offload = ["watermarker"]
def __init__(
self,
@@ -149,21 +149,6 @@ def __init__(
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
- # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.remove_all_hooks
- def remove_all_hooks(self):
- if is_accelerate_available():
- from accelerate.hooks import remove_hook_from_module
- else:
- raise ImportError("Please install accelerate via `pip install accelerate`")
-
- for model in [self.text_encoder, self.unet, self.safety_checker]:
- if model is not None:
- remove_hook_from_module(model, recurse=True)
-
- self.unet_offload_hook = None
- self.text_encoder_offload_hook = None
- self.final_offload_hook = None
-
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
@@ -471,9 +456,6 @@ def run_safety_checker(self, image, device, dtype):
nsfw_detected = None
watermark_detected = None
- if hasattr(self, "unet_offload_hook") and self.unet_offload_hook is not None:
- self.unet_offload_hook.offload()
-
return image, nsfw_detected, watermark_detected
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline.prepare_extra_step_kwargs
@@ -775,6 +757,9 @@ def __call__(
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(0)
+
# 5. Prepare intermediate images
num_channels = self.unet.config.in_channels // 2
intermediate_images = self.prepare_intermediate_images(
diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
index e4583699e79e..92de940332c4 100644
--- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
@@ -89,8 +89,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -588,7 +588,12 @@ def check_inputs(
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
index 156e52c249d9..48b3b96483d5 100644
--- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -129,8 +129,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
index dee93fc2eb53..f44a1ca74ee4 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
@@ -469,7 +469,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py
index ddc866ef9b86..9421531d273e 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py
@@ -448,7 +448,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
index c819e5728181..5f744578810b 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
@@ -661,7 +661,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index 550756cd80d8..3c3bd526692d 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -1000,8 +1000,8 @@ def disable_freeu(self):
def fuse_qkv_projections(self):
"""
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
- key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
@@ -1112,8 +1112,8 @@ def forward(
Returns:
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
- If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
- a `tuple` is returned where the first element is the sample tensor.
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
@@ -2238,6 +2238,7 @@ def __init__(
self,
in_channels: int,
temb_channels: int,
+ out_channels: Optional[int] = None,
dropout: float = 0.0,
num_layers: int = 1,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -2245,6 +2246,7 @@ def __init__(
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
+ resnet_groups_out: Optional[int] = None,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
@@ -2256,6 +2258,10 @@ def __init__(
):
super().__init__()
+ out_channels = out_channels or in_channels
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -2264,14 +2270,17 @@ def __init__(
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+ resnet_groups_out = resnet_groups_out or resnet_groups
+
# there is always at least one resnet
resnets = [
ResnetBlockFlat(
in_channels=in_channels,
- out_channels=in_channels,
+ out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
+ groups_out=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
@@ -2286,11 +2295,11 @@ def __init__(
attentions.append(
Transformer2DModel(
num_attention_heads,
- in_channels // num_attention_heads,
- in_channels=in_channels,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
num_layers=transformer_layers_per_block[i],
cross_attention_dim=cross_attention_dim,
- norm_num_groups=resnet_groups,
+ norm_num_groups=resnet_groups_out,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
@@ -2300,8 +2309,8 @@ def __init__(
attentions.append(
DualTransformer2DModel(
num_attention_heads,
- in_channels // num_attention_heads,
- in_channels=in_channels,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
@@ -2309,11 +2318,11 @@ def __init__(
)
resnets.append(
ResnetBlockFlat(
- in_channels=in_channels,
- out_channels=in_channels,
+ in_channels=out_channels,
+ out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
- groups=resnet_groups,
+ groups=resnet_groups_out,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
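The flat mid-block now accepts `out_channels` (and a matching `resnet_groups_out`), so the first resnet can change the channel width while the attention blocks and remaining resnets operate on the new width. A sketch of the channel bookkeeping only, with the real `ResnetBlockFlat`/`Transformer2DModel` construction elided:

```py
# Channel bookkeeping only; building the actual ResnetBlockFlat / Transformer2DModel
# instances is omitted.
def mid_block_channel_plan(in_channels, out_channels=None, num_layers=1):
    out_channels = out_channels or in_channels  # same defaulting as the patched __init__

    plan = [("resnet", in_channels, out_channels)]  # first resnet changes the width
    for _ in range(num_layers):
        plan.append(("attention", out_channels, out_channels))
        plan.append(("resnet", out_channels, out_channels))
    return plan

print(mid_block_channel_plan(320))       # out_channels defaults to in_channels: all 320 -> 320
print(mid_block_channel_plan(320, 640))  # the 320 -> 640 transition happens in the first resnet
```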
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
index 8af739bbe428..b1117044cf18 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
@@ -348,7 +348,12 @@ def check_inputs(self, prompt, image, height, width, callback_steps):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
index 345c15f18d89..59aa370ec2f6 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
@@ -214,7 +214,12 @@ def check_inputs(self, image, height, width, callback_steps):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
index 0b2518f7e244..0c76e5837b99 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
@@ -300,7 +300,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/free_init_utils.py b/src/diffusers/pipelines/free_init_utils.py
index a6eabc930172..4f7965a038c5 100644
--- a/src/diffusers/pipelines/free_init_utils.py
+++ b/src/diffusers/pipelines/free_init_utils.py
@@ -41,20 +41,20 @@ def enable_free_init(
num_iters (`int`, *optional*, defaults to `3`):
Number of FreeInit noise re-initialization iterations.
use_fast_sampling (`bool`, *optional*, defaults to `False`):
- Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables
- the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
+ Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables the
+ "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
method (`str`, *optional*, defaults to `butterworth`):
- Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the
- FreeInit low pass filter.
+ Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the FreeInit low
+ pass filter.
order (`int`, *optional*, defaults to `4`):
Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour
whereas lower values lead to `gaussian` method behaviour.
spatial_stop_frequency (`float`, *optional*, defaults to `0.25`):
- Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in
- the original implementation.
+ Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in the
+ original implementation.
temporal_stop_frequency (`float`, *optional*, defaults to `0.25`):
- Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in
- the original implementation.
+ Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in the
+ original implementation.
"""
self._free_init_num_iters = num_iters
self._free_init_use_fast_sampling = use_fast_sampling
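
Since the rewrapped docstring above spells out the FreeInit knobs, here is a short hedged usage sketch. The checkpoints are illustrative (taken from the AnimateDiff docs); any video pipeline that mixes in `FreeInitMixin` should expose `enable_free_init()` in the same way:

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

# Illustrative checkpoints; any FreeInitMixin pipeline works similarly.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# Parameters mirror the docstring above: more iterations and the default "butterworth"
# filter trade extra compute for better temporal consistency.
pipe.enable_free_init(num_iters=3, use_fast_sampling=False, method="butterworth", order=4)
frames = pipe("a panda dancing", num_inference_steps=25).frames[0]
pipe.disable_free_init()
```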
diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
index cb6f3e300904..a6b9499f5542 100644
--- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
+++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
@@ -43,10 +43,14 @@
>>> from diffusers import I2VGenXLPipeline
>>> from diffusers.utils import export_to_gif, load_image
- >>> pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
+ >>> pipeline = I2VGenXLPipeline.from_pretrained(
+ ... "ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16"
+ ... )
>>> pipeline.enable_model_cpu_offload()
- >>> image_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png"
+ >>> image_url = (
+ ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png"
+ ... )
>>> image = load_image(image_url).convert("RGB")
>>> prompt = "Papers were floating in the air on a table in the library"
@@ -59,7 +63,7 @@
... num_inference_steps=50,
... negative_prompt=negative_prompt,
... guidance_scale=9.0,
- ... generator=generator
+ ... generator=generator,
... ).frames[0]
>>> video_path = export_to_gif(frames, "i2v.gif")
```
@@ -95,7 +99,8 @@ class I2VGenXLPipelineOutput(BaseOutput):
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
- List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing denoised
+        List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing denoised
PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`
"""
@@ -551,7 +556,8 @@ def __call__(
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
target_fps (`int`, *optional*):
- Frames per second. The rate at which the generated images shall be exported to a video after generation. This is also used as a "micro-condition" while generation.
+ Frames per second. The rate at which the generated images shall be exported to a video after
+                generation. This is also used as a "micro-condition" during generation.
num_frames (`int`, *optional*):
The number of video frames to generate.
num_inference_steps (`int`, *optional*):
@@ -568,9 +574,9 @@ def __call__(
num_videos_per_prompt (`int`, *optional*):
The number of images to generate per prompt.
decode_chunk_size (`int`, *optional*):
- The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
- between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
- for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
+ The number of frames to decode at a time. The higher the chunk size, the higher the temporal
+ consistency between frames, but also the higher the memory consumption. By default, the decoder will
+ decode all frames at once for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
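
The rewrapped `target_fps` and `decode_chunk_size` descriptions above describe a quality/memory trade-off. A short sketch of passing both, reusing the `pipeline`, `image`, and prompt from the example at the top of this hunk (the chunk size of 4 is an illustrative choice):

```python
>>> frames = pipeline(
...     prompt="Papers were floating in the air on a table in the library",
...     image=image,
...     target_fps=16,  # also used as a "micro-condition" during generation
...     decode_chunk_size=4,  # decode 4 frames at a time to cap VAE memory; raise for maximal quality
...     generator=torch.manual_seed(8888),
... ).frames[0]
```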
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
index 6b1ed62f8ae6..cbe66a63f4c8 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py
@@ -143,6 +143,7 @@ class KandinskyCombinedPipeline(DiffusionPipeline):
_load_connected_pipes = True
model_cpu_offload_seq = "text_encoder->unet->movq->prior_prior->prior_image_encoder->prior_text_encoder"
+ _exclude_from_cpu_offload = ["prior_prior"]
def __init__(
self,
@@ -360,6 +361,7 @@ class KandinskyImg2ImgCombinedPipeline(DiffusionPipeline):
_load_connected_pipes = True
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->" "text_encoder->unet->movq"
+ _exclude_from_cpu_offload = ["prior_prior"]
def __init__(
self,
@@ -600,6 +602,7 @@ class KandinskyInpaintCombinedPipeline(DiffusionPipeline):
_load_connected_pipes = True
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->prior_prior->text_encoder->unet->movq"
+ _exclude_from_cpu_offload = ["prior_prior"]
def __init__(
self,
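
The new `_exclude_from_cpu_offload = ["prior_prior"]` entries (here and in the Kandinsky 2.2 combined pipelines below) keep the prior's diffusion module out of the hook-based offload chain; my reading of the base-pipeline offload logic is that `enable_model_cpu_offload()` walks `model_cpu_offload_seq` but skips anything listed in `_exclude_from_cpu_offload`. A hedged usage sketch, with the checkpoint taken from the Kandinsky docs and used illustratively:

```python
import torch
from diffusers import AutoPipelineForText2Image

# AutoPipelineForText2Image resolves this checkpoint to the combined Kandinsky pipeline.
pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)

# With the class attribute above, the "prior_prior" module is skipped when offload hooks are
# installed, while the remaining components follow model_cpu_offload_seq.
pipe.enable_model_cpu_offload()
image = pipe("A fantasy landscape, Cinematic lighting", num_inference_steps=25).images[0]
```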
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
index 20f5d45bb214..06d94d2cb79f 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py
@@ -135,6 +135,7 @@ class KandinskyV22CombinedPipeline(DiffusionPipeline):
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->unet->movq"
_load_connected_pipes = True
+ _exclude_from_cpu_offload = ["prior_prior"]
def __init__(
self,
@@ -362,6 +363,7 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->unet->movq"
_load_connected_pipes = True
+ _exclude_from_cpu_offload = ["prior_prior"]
def __init__(
self,
@@ -610,6 +612,7 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
model_cpu_offload_seq = "prior_text_encoder->prior_image_encoder->unet->movq"
_load_connected_pipes = True
+ _exclude_from_cpu_offload = ["prior_prior"]
def __init__(
self,
diff --git a/src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py b/src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py
index 4fe8c54eb7fc..5360632275b4 100755
--- a/src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py
+++ b/src/diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py
@@ -35,10 +35,10 @@
def convert_state_dict(unet_state_dict):
"""
- Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model.
Args:
- unet_model (torch.nn.Module): The original U-Net model.
- unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet model to match keys with.
+    Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model.
+        unet_model (torch.nn.Module): The original U-Net model.
+        unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet model to match keys with.
Returns:
OrderedDict: The converted state dictionary.
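
For context on what a converter like `convert_state_dict` does, here is a hedged, generic sketch of the pattern it describes (translate key names, keep tensors untouched). The prefixes in `rename_map` are hypothetical and are not the real Kandinsky3 mapping, which this hunk does not show:

```python
from collections import OrderedDict

import torch


def rename_keys(unet_state_dict: "OrderedDict[str, torch.Tensor]") -> "OrderedDict[str, torch.Tensor]":
    # Illustrative only: demonstrates the general key-renaming pattern of such converters.
    rename_map = {"in_layer.": "conv_in.", "out_layer.": "conv_out."}  # hypothetical prefixes
    converted = OrderedDict()
    for key, tensor in unet_state_dict.items():
        for old, new in rename_map.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        converted[key] = tensor
    return converted
```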
diff --git a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
index fcf7ddcb9966..85d6418d07cf 100644
--- a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
@@ -8,7 +8,6 @@
from ...schedulers import DDPMScheduler
from ...utils import (
deprecate,
- is_accelerate_available,
logging,
replace_example_docstring,
)
@@ -24,7 +23,9 @@
>>> from diffusers import AutoPipelineForText2Image
>>> import torch
- >>> pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
+ >>> pipe = AutoPipelineForText2Image.from_pretrained(
+ ... "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
+ ... )
>>> pipe.enable_model_cpu_offload()
>>> prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
@@ -70,20 +71,6 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq
)
- def remove_all_hooks(self):
- if is_accelerate_available():
- from accelerate.hooks import remove_hook_from_module
- else:
- raise ImportError("Please install accelerate via `pip install accelerate`")
-
- for model in [self.text_encoder, self.unet, self.movq]:
- if model is not None:
- remove_hook_from_module(model, recurse=True)
-
- self.unet_offload_hook = None
- self.text_encoder_offload_hook = None
- self.final_offload_hook = None
-
def process_embeds(self, embeddings, attention_mask, cut_context):
if cut_context:
embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0])
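
The pipeline-specific `remove_all_hooks` helper is dropped here (and in the img2img pipeline below), since `DiffusionPipeline` already provides generic hook management. A hedged sketch of the equivalent flow, assuming the inherited base-class `remove_all_hooks()` helper and reusing the checkpoint from the example above:

```python
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
_ = pipe("An avocado armchair", num_inference_steps=25).images[0]

# Instead of the removed Kandinsky-specific helper, the inherited DiffusionPipeline method
# strips the accelerate offload hooks from every component in one call.
pipe.remove_all_hooks()
```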
diff --git a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
index 7f4164a04d1e..16a57b6b8c3f 100644
--- a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
@@ -12,7 +12,6 @@
from ...schedulers import DDPMScheduler
from ...utils import (
deprecate,
- is_accelerate_available,
logging,
replace_example_docstring,
)
@@ -29,11 +28,15 @@
>>> from diffusers.utils import load_image
>>> import torch
- >>> pipe = AutoPipelineForImage2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
+ >>> pipe = AutoPipelineForImage2Image.from_pretrained(
+ ... "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
+ ... )
>>> pipe.enable_model_cpu_offload()
>>> prompt = "A painting of the inside of a subway train with tiny raccoons."
- >>> image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png")
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png"
+ ... )
>>> generator = torch.Generator(device="cpu").manual_seed(0)
>>> image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0]
@@ -92,20 +95,6 @@ def get_timesteps(self, num_inference_steps, strength, device):
return timesteps, num_inference_steps - t_start
- def remove_all_hooks(self):
- if is_accelerate_available():
- from accelerate.hooks import remove_hook_from_module
- else:
- raise ImportError("Please install accelerate via `pip install accelerate`")
-
- for model in [self.text_encoder, self.unet]:
- if model is not None:
- remove_hook_from_module(model, recurse=True)
-
- self.unet_offload_hook = None
- self.text_encoder_offload_hook = None
- self.final_offload_hook = None
-
def _process_embeds(self, embeddings, attention_mask, cut_context):
# return embeddings, attention_mask
if cut_context:
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
index e8482ffe9ce2..8957d7140ef1 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -73,8 +73,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -749,10 +749,10 @@ def __call__(
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
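
The rewrapped `retrieve_timesteps` docstring above describes its contract; a small sketch of calling it directly, assuming the signature documented here (`scheduler`, `num_inference_steps`, `device`, optional custom `timesteps`) and the module path shown in this diff:

```python
from diffusers import LCMScheduler
from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_img2img import (
    retrieve_timesteps,
)

scheduler = LCMScheduler()

# Standard path: let the scheduler build the schedule for a given step count.
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=4, device="cpu")
print(num_inference_steps, timesteps)

# Custom path: pass explicit `timesteps` and leave `num_inference_steps` as None; per the
# docstring this only works for schedulers whose set_timesteps() accepts a `timesteps` argument.
```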
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
index 259a65c80782..a69d49f1ffbb 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
@@ -77,8 +77,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -474,7 +474,12 @@ def run_safety_checker(self, image, device, dtype):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -681,10 +686,10 @@ def __call__(
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
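
The rewrapped `ip_adapter_image_embeds` description spells out the expected structure. A hedged sketch of building such an input by hand; the embedding width of 1024 and the single IP-Adapter are illustrative, and the exact packing of the negative embedding should be checked against `prepare_ip_adapter_image_embeds` in the installed version:

```python
import torch

# Illustrative sizes; emb_dim depends on the image encoder backing the IP-Adapter.
batch_size, num_images, emb_dim = 1, 1, 1024

# One entry per loaded IP-Adapter, each of shape (batch_size, num_images, emb_dim).
# When classifier-free guidance is on, the docstring above additionally expects the
# negative image embedding to be included in the same tensor.
ip_adapter_image_embeds = [torch.randn(batch_size, num_images, emb_dim)]
```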
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
index a6357c4cd3a1..619be13a8f36 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
@@ -40,30 +40,21 @@
>>> from io import BytesIO
>>> from diffusers import LEditsPPPipelineStableDiffusion
+ >>> from diffusers.utils import load_image
>>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
- >>> def download_image(url):
- ... response = requests.get(url)
- ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
-
>>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
- >>> image = download_image(img_url)
+ >>> image = load_image(img_url).convert("RGB")
- >>> _ = pipe.invert(
- ... image = image,
- ... num_inversion_steps=50,
- ... skip=0.1
- ... )
+ >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)
>>> edited_image = pipe(
- ... editing_prompt=["cherry blossom"],
- ... edit_guidance_scale=10.0,
- ... edit_threshold=0.75,
- ).images[0]
+ ... editing_prompt=["cherry blossom"], edit_guidance_scale=10.0, edit_threshold=0.75
+ ... ).images[0]
```
"""
@@ -279,8 +270,8 @@ class LEditsPPPipelineStableDiffusion(
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
- [`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]. If any other scheduler is passed it will automatically
- be set to [`DPMSolverMultistepScheduler`].
+ [`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]. If any other scheduler is passed it will
+ automatically be set to [`DPMSolverMultistepScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
@@ -531,8 +522,7 @@ def encode_prompt(
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
editing_prompt (`str` or `List[str]`, *optional*):
- Editing prompt(s) to be encoded. If not defined, one has to pass
- `editing_prompt_embeds` instead.
+ Editing prompt(s) to be encoded. If not defined, one has to pass `editing_prompt_embeds` instead.
editing_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -734,8 +724,9 @@ def __call__(
**kwargs,
):
r"""
- The call function to the pipeline for editing. The [`~pipelines.ledits_pp.LEditsPPPipelineStableDiffusion.invert`]
- method has to be called beforehand. Edits will always be performed for the last inverted image(s).
+ The call function to the pipeline for editing. The
+ [`~pipelines.ledits_pp.LEditsPPPipelineStableDiffusion.invert`] method has to be called beforehand. Edits will
+ always be performed for the last inverted image(s).
Args:
negative_prompt (`str` or `List[str]`, *optional*):
@@ -748,49 +739,51 @@ def __call__(
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] instead of a
- plain tuple.
+ Whether or not to return a [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] instead of a plain
+ tuple.
editing_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. The image is reconstructed by setting
- `editing_prompt = None`. Guidance direction of prompt should be specified via `reverse_editing_direction`.
+ `editing_prompt = None`. Guidance direction of prompt should be specified via
+ `reverse_editing_direction`.
editing_prompt_embeds (`torch.Tensor>`, *optional*):
- Pre-computed embeddings to use for guiding the image generation. Guidance direction of embedding should be
- specified via `reverse_editing_direction`.
+ Pre-computed embeddings to use for guiding the image generation. Guidance direction of embedding should
+ be specified via `reverse_editing_direction`.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`):
Whether the corresponding prompt in `editing_prompt` should be increased or decreased.
edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5):
- Guidance scale for guiding the image generation. If provided as list values should correspond to `editing_prompt`.
- `edit_guidance_scale` is defined as `s_e` of equation 12 of
- [LEDITS++ Paper](https://arxiv.org/abs/2301.12247).
+ Guidance scale for guiding the image generation. If provided as list values should correspond to
+ `editing_prompt`. `edit_guidance_scale` is defined as `s_e` of equation 12 of [LEDITS++
+ Paper](https://arxiv.org/abs/2301.12247).
edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10):
Number of diffusion steps (for each prompt) for which guidance will not be applied.
edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`):
Number of diffusion steps (for each prompt) after which guidance will no longer be applied.
edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9):
Masking threshold of guidance. Threshold should be proportional to the image region that is modified.
- 'edit_threshold' is defined as 'λ' of equation 12 of [LEDITS++ Paper](https://arxiv.org/abs/2301.12247).
+ 'edit_threshold' is defined as 'λ' of equation 12 of [LEDITS++
+ Paper](https://arxiv.org/abs/2301.12247).
user_mask (`torch.FloatTensor`, *optional*):
- User-provided mask for even better control over the editing process. This is helpful when LEDITS++'s implicit
- masks do not meet user preferences.
+ User-provided mask for even better control over the editing process. This is helpful when LEDITS++'s
+ implicit masks do not meet user preferences.
sem_guidance (`List[torch.Tensor]`, *optional*):
List of pre-generated guidance vectors to be applied at generation. Length of the list has to
correspond to `num_inference_steps`.
use_cross_attn_mask (`bool`, defaults to `False`):
Whether cross-attention masks are used. Cross-attention masks are always used when use_intersect_mask
- is set to true. Cross-attention masks are defined as 'M^1' of equation 12 of
- [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf).
+ is set to true. Cross-attention masks are defined as 'M^1' of equation 12 of [LEDITS++
+ paper](https://arxiv.org/pdf/2311.16711.pdf).
use_intersect_mask (`bool`, defaults to `True`):
- Whether the masking term is calculated as intersection of cross-attention masks and masks derived
- from the noise estimate. Cross-attention mask are defined as 'M^1' and masks derived from the noise
- estimate are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf).
+ Whether the masking term is calculated as intersection of cross-attention masks and masks derived from
+ the noise estimate. Cross-attention mask are defined as 'M^1' and masks derived from the noise estimate
+ are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf).
attn_store_steps (`List[int]`, *optional*):
Steps for which the attention maps are stored in the AttentionStore. Just for visualization purposes.
store_averaged_over_steps (`bool`, defaults to `True`):
- Whether the attention maps for the 'attn_store_steps' are stored averaged over the diffusion steps.
- If False, attention maps for each step are stores separately. Just for visualization purposes.
+ Whether the attention maps for the 'attn_store_steps' are stored averaged over the diffusion steps. If
+                False, attention maps for each step are stored separately. Just for visualization purposes.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -815,10 +808,10 @@ def __call__(
Returns:
[`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] or `tuple`:
- [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] if `return_dict` is True,
- otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the
- second element is a list of `bool`s denoting whether the corresponding generated image likely represents
- "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+                [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images, and the second element is a list
+ of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw)
+ content, according to the `safety_checker`.
"""
if self.inversion_steps is None:
@@ -1219,9 +1212,9 @@ def invert(
crops_coords: Optional[Tuple[int, int, int, int]] = None,
):
r"""
- The function to the pipeline for image inversion as described by the [LEDITS++ Paper](https://arxiv.org/abs/2301.12247).
- If the scheduler is set to [`~schedulers.DDIMScheduler`] the inversion proposed by [edit-friendly DPDM](https://arxiv.org/abs/2304.06140)
- will be performed instead.
+ The function to the pipeline for image inversion as described by the [LEDITS++
+ Paper](https://arxiv.org/abs/2301.12247). If the scheduler is set to [`~schedulers.DDIMScheduler`] the
+ inversion proposed by [edit-friendly DPDM](https://arxiv.org/abs/2304.06140) will be performed instead.
Args:
image (`PipelineImageInput`):
@@ -1238,8 +1231,8 @@ def invert(
Portion of initial steps that will be ignored for inversion and subsequent generation. Lower values
will lead to stronger changes to the input image. `skip` has to be between `0` and `1`.
generator (`torch.Generator`, *optional*):
- A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
- inversion deterministic.
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make inversion
+ deterministic.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -1247,23 +1240,24 @@ def invert(
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
height (`int`, *optional*, defaults to `None`):
- The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height.
+                The height of the preprocessed image. If `None`, will use `get_default_height_width()` to get the
+                default height.
width (`int`, *optional*`, defaults to `None`):
- The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
+                The width of the preprocessed image. If `None`, will use `get_default_height_width()` to get the default width.
resize_mode (`str`, *optional*, defaults to `default`):
- The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit
- within the specified width and height, and it may not maintaining the original aspect ratio.
- If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
- within the dimensions, filling empty with data from image.
- If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
- within the dimensions, cropping the excess.
- Note that resize_mode `fill` and `crop` are only supported for PIL image input.
+ The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
+                the specified width and height, and it may not maintain the original aspect ratio. If `fill`, will
+                resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
+                center the image within the dimensions, filling the empty space with data from the image. If `crop`, will resize the
+ image to fit within the specified width and height, maintaining the aspect ratio, and then center the
+ image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
+ supported for PIL image input.
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
The crop coordinates for each image in the batch. If `None`, will not crop the image.
Returns:
- [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]:
- Output will contain the resized input image(s) and respective VAE reconstruction(s).
+ [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
+ and respective VAE reconstruction(s).
"""
# Reset attn processor, we do not want to store attn maps during inversion
self.unet.set_attn_processor(AttnProcessor())
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
index b1f773cb864b..cfab70926a4a 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
@@ -85,25 +85,23 @@
... )
>>> pipe = pipe.to("cuda")
+
>>> def download_image(url):
... response = requests.get(url)
... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
>>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
>>> image = download_image(img_url)
- >>> _ = pipe.invert(
- ... image = image,
- ... num_inversion_steps=50,
- ... skip=0.2
- ... )
+ >>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)
>>> edited_image = pipe(
- ... editing_prompt=["tennis ball","tomato"],
- ... reverse_editing_direction=[True,False],
- ... edit_guidance_scale=[5.0,10.0],
- ... edit_threshold=[0.9,0.85],
- ).images[0]
+ ... editing_prompt=["tennis ball", "tomato"],
+ ... reverse_editing_direction=[True, False],
+ ... edit_guidance_scale=[5.0, 10.0],
+ ... edit_threshold=[0.9, 0.85],
+ ... ).images[0]
```
"""
@@ -292,9 +290,9 @@ class LEditsPPPipelineStableDiffusionXL(
"""
Pipeline for textual image editing using LEDits++ with Stable Diffusion XL.
- This model inherits from [`DiffusionPipeline`] and builds on the [`StableDiffusionXLPipeline`]. Check the superclass
- documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular
- device, etc.).
+ This model inherits from [`DiffusionPipeline`] and builds on the [`StableDiffusionXLPipeline`]. Check the
+ superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a
+ particular device, etc.).
In addition the pipeline inherits the following loading methods:
- *LoRA*: [`LEditsPPPipelineStableDiffusionXL.load_lora_weights`]
@@ -325,8 +323,8 @@ class LEditsPPPipelineStableDiffusionXL(
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
- [`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]. If any other scheduler is passed it will automatically
- be set to [`DPMSolverMultistepScheduler`].
+ [`DPMSolverMultistepScheduler`] or [`DDIMScheduler`]. If any other scheduler is passed it will
+ automatically be set to [`DPMSolverMultistepScheduler`].
force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
`stabilityai/stable-diffusion-xl-base-1-0`.
@@ -453,9 +451,9 @@ def encode_prompt(
Editing prompt(s) to be encoded. If not defined and 'enable_edit_guidance' is True, one has to pass
`editing_prompt_embeds` instead.
editing_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
- weighting. If not provided and 'enable_edit_guidance' is True, editing_prompt_embeds will be generated from `editing_prompt` input
- argument.
+ Pre-generated edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided and 'enable_edit_guidance' is True, editing_prompt_embeds will be generated from
+ `editing_prompt` input argument.
editing_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated edit pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled editing_pooled_prompt_embeds will be generated from `editing_prompt`
@@ -835,8 +833,9 @@ def __call__(
**kwargs,
):
r"""
- The call function to the pipeline for editing. The [`~pipelines.ledits_pp.LEditsPPPipelineStableDiffusionXL.invert`]
- method has to be called beforehand. Edits will always be performed for the last inverted image(s).
+ The call function to the pipeline for editing. The
+ [`~pipelines.ledits_pp.LEditsPPPipelineStableDiffusionXL.invert`] method has to be called beforehand. Edits
+ will always be performed for the last inverted image(s).
Args:
denoising_end (`float`, *optional*):
@@ -894,11 +893,11 @@ def __call__(
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
editing_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. The image is reconstructed by setting
- `editing_prompt = None`. Guidance direction of prompt should be specified via `reverse_editing_direction`.
+ `editing_prompt = None`. Guidance direction of prompt should be specified via
+ `reverse_editing_direction`.
editing_prompt_embeddings (`torch.Tensor`, *optional*):
- Pre-generated edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
- weighting. If not provided, editing_prompt_embeddings will be generated from `editing_prompt` input
- argument.
+ Pre-generated edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, editing_prompt_embeddings will be generated from `editing_prompt` input argument.
editing_pooled_prompt_embeddings (`torch.Tensor`, *optional*):
Pre-generated pooled edit text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, editing_prompt_embeddings will be generated from `editing_prompt` input
@@ -906,35 +905,36 @@ def __call__(
reverse_editing_direction (`bool` or `List[bool]`, *optional*, defaults to `False`):
Whether the corresponding prompt in `editing_prompt` should be increased or decreased.
edit_guidance_scale (`float` or `List[float]`, *optional*, defaults to 5):
- Guidance scale for guiding the image generation. If provided as list values should correspond to `editing_prompt`.
- `edit_guidance_scale` is defined as `s_e` of equation 12 of
- [LEDITS++ Paper](https://arxiv.org/abs/2301.12247).
+ Guidance scale for guiding the image generation. If provided as list values should correspond to
+ `editing_prompt`. `edit_guidance_scale` is defined as `s_e` of equation 12 of [LEDITS++
+ Paper](https://arxiv.org/abs/2301.12247).
edit_warmup_steps (`float` or `List[float]`, *optional*, defaults to 10):
Number of diffusion steps (for each prompt) for which guidance is not applied.
edit_cooldown_steps (`float` or `List[float]`, *optional*, defaults to `None`):
Number of diffusion steps (for each prompt) after which guidance is no longer applied.
edit_threshold (`float` or `List[float]`, *optional*, defaults to 0.9):
Masking threshold of guidance. Threshold should be proportional to the image region that is modified.
- 'edit_threshold' is defined as 'λ' of equation 12 of [LEDITS++ Paper](https://arxiv.org/abs/2301.12247).
+ 'edit_threshold' is defined as 'λ' of equation 12 of [LEDITS++
+ Paper](https://arxiv.org/abs/2301.12247).
sem_guidance (`List[torch.Tensor]`, *optional*):
List of pre-generated guidance vectors to be applied at generation. Length of the list has to
correspond to `num_inference_steps`.
use_cross_attn_mask:
Whether cross-attention masks are used. Cross-attention masks are always used when use_intersect_mask
- is set to true. Cross-attention masks are defined as 'M^1' of equation 12 of
- [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf).
+ is set to true. Cross-attention masks are defined as 'M^1' of equation 12 of [LEDITS++
+ paper](https://arxiv.org/pdf/2311.16711.pdf).
use_intersect_mask:
- Whether the masking term is calculated as intersection of cross-attention masks and masks derived
- from the noise estimate. Cross-attention mask are defined as 'M^1' and masks derived from the noise
- estimate are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf).
+ Whether the masking term is calculated as intersection of cross-attention masks and masks derived from
+ the noise estimate. Cross-attention mask are defined as 'M^1' and masks derived from the noise estimate
+ are defined as 'M^2' of equation 12 of [LEDITS++ paper](https://arxiv.org/pdf/2311.16711.pdf).
user_mask:
- User-provided mask for even better control over the editing process. This is helpful when LEDITS++'s implicit
- masks do not meet user preferences.
+ User-provided mask for even better control over the editing process. This is helpful when LEDITS++'s
+ implicit masks do not meet user preferences.
attn_store_steps:
Steps for which the attention maps are stored in the AttentionStore. Just for visualization purposes.
store_averaged_over_steps:
- Whether the attention maps for the 'attn_store_steps' are stored averaged over the diffusion steps.
- If False, attention maps for each step are stores separately. Just for visualization purposes.
+ Whether the attention maps for the 'attn_store_steps' are stored averaged over the diffusion steps. If
+                False, attention maps for each step are stored separately. Just for visualization purposes.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -952,8 +952,8 @@ def __call__(
Returns:
[`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] or `tuple`:
- [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] if `return_dict` is True,
- otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
+                [`~pipelines.ledits_pp.LEditsPPDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
"""
if self.inversion_steps is None:
raise ValueError(
@@ -1446,9 +1446,9 @@ def invert(
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
r"""
- The function to the pipeline for image inversion as described by the [LEDITS++ Paper](https://arxiv.org/abs/2301.12247).
- If the scheduler is set to [`~schedulers.DDIMScheduler`] the inversion proposed by [edit-friendly DPDM](https://arxiv.org/abs/2304.06140)
- will be performed instead.
+ The function to the pipeline for image inversion as described by the [LEDITS++
+ Paper](https://arxiv.org/abs/2301.12247). If the scheduler is set to [`~schedulers.DDIMScheduler`] the
+ inversion proposed by [edit-friendly DPDM](https://arxiv.org/abs/2304.06140) will be performed instead.
Args:
image (`PipelineImageInput`):
@@ -1472,8 +1472,8 @@ def invert(
Portion of initial steps that will be ignored for inversion and subsequent generation. Lower values
will lead to stronger changes to the input image. `skip` has to be between `0` and `1`.
generator (`torch.Generator`, *optional*):
- A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
- inversion deterministic.
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make inversion
+ deterministic.
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
@@ -1488,8 +1488,8 @@ def invert(
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
Returns:
- [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]:
- Output will contain the resized input image(s) and respective VAE reconstruction(s).
+ [`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
+ and respective VAE reconstruction(s).
"""
# Reset attn processor, we do not want to store attn maps during inversion
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_output.py b/src/diffusers/pipelines/ledits_pp/pipeline_output.py
index b90005c97c4a..756be82b0069 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_output.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_output.py
@@ -35,8 +35,8 @@ class LEditsPPInversionPipelineOutput(BaseOutput):
List of the cropped and resized input images as PIL images of length `batch_size` or NumPy array of shape `
(batch_size, height, width, num_channels)`.
vae_reconstruction_images (`List[PIL.Image.Image]` or `np.ndarray`)
- List of VAE reconstruction of all input images as PIL images of length `batch_size` or NumPy array of shape `
- (batch_size, height, width, num_channels)`.
+ List of VAE reconstruction of all input images as PIL images of length `batch_size` or NumPy array of shape
+ ` (batch_size, height, width, num_channels)`.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
index 5fde3450b9a0..2bd828f0df24 100644
--- a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
+++ b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
@@ -363,8 +363,8 @@ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, devic
shape = (
batch_size,
num_channels_latents,
- height // self.vae_scale_factor,
- self.vocoder.config.model_in_dim // self.vae_scale_factor,
+ int(height) // self.vae_scale_factor,
+ int(self.vocoder.config.model_in_dim) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index 8a24f134e793..263d507bbc75 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -283,7 +283,12 @@ def check_inputs(self, image, height, width, callback_steps):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py
index bf55ac9f49bb..acd8659d2a7c 100644
--- a/src/diffusers/pipelines/pia/pipeline_pia.py
+++ b/src/diffusers/pipelines/pia/pipeline_pia.py
@@ -59,6 +59,7 @@
... PIAPipeline,
... )
>>> from diffusers.utils import export_to_gif, load_image
+
>>> adapter = MotionAdapter.from_pretrained("../checkpoints/pia-diffusers")
>>> pipe = PIAPipeline.from_pretrained("SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter)
>>> pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
@@ -135,9 +136,9 @@ class PIAPipelineOutput(BaseOutput):
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
- Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`,
- NumPy array of shape `(batch_size, num_frames, channels, height, width,
- Torch tensor of shape `(batch_size, num_frames, channels, height, width)`.
+        Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`, a NumPy array of
+        shape `(batch_size, num_frames, channels, height, width)`, or a Torch tensor of shape `(batch_size, num_frames,
+        channels, height, width)`.
"""
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
@@ -759,16 +760,15 @@ def __call__(
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
motion_scale: (`int`, *optional*, defaults to 0):
- Parameter that controls the amount and type of motion that is added to the image. Increasing the value increases the amount of motion, while specific
- ranges of values control the type of motion that is added. Must be between 0 and 8.
- Set between 0-2 to only increase the amount of motion.
- Set between 3-5 to create looping motion.
- Set between 6-8 to perform motion with image style transfer.
+ Parameter that controls the amount and type of motion that is added to the image. Increasing the value
+ increases the amount of motion, while specific ranges of values control the type of motion that is
+ added. Must be between 0 and 8. Set between 0-2 to only increase the amount of motion. Set between 3-5
+ to create looping motion. Set between 6-8 to perform motion with image style transfer.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
@@ -795,8 +795,8 @@ def __call__(
Returns:
[`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] or `tuple`:
- If `return_dict` is `True`, [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] is
- returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
+ If `return_dict` is `True`, [`~pipelines.pia.pipeline_pia.PIAPipelineOutput`] is returned, otherwise a
+ `tuple` is returned where the first element is a list with the generated frames.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
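
The rewrapped `motion_scale` description maps value ranges to behaviours. A short sketch in the same doctest style, assuming the `pipe` and an input `image` as set up in the (truncated) example near the top of this file's diff; the prompt is illustrative and `export_to_gif` is the helper imported there:

```python
>>> # motion_scale ranges documented above: 0-2 more motion, 3-5 looping motion, 6-8 motion with style transfer
>>> output = pipe(
...     image=image,
...     prompt="cherry blossom swaying in the wind",
...     motion_scale=4,  # looping-motion range
...     num_inference_steps=25,
...     generator=torch.Generator("cpu").manual_seed(0),
... )
>>> frames = output.frames[0]
>>> export_to_gif(frames, "pia-loop.gif")
```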
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 11b2c549096b..15fb34e72d24 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -22,15 +22,19 @@
from typing import Any, Dict, List, Optional, Union
import torch
-from huggingface_hub import (
- model_info,
-)
+from huggingface_hub import model_info
+from huggingface_hub.utils import validate_hf_hub_args
from packaging import version
+from .. import __version__
from ..utils import (
+ FLAX_WEIGHTS_NAME,
+ ONNX_EXTERNAL_WEIGHTS_NAME,
+ ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
get_class_from_dynamic_module,
+ is_accelerate_available,
is_peft_available,
is_transformers_available,
logging,
@@ -44,9 +48,12 @@
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
-from huggingface_hub.utils import validate_hf_hub_args
-from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
+if is_accelerate_available():
+ import accelerate
+ from accelerate import dispatch_model
+ from accelerate.hooks import remove_hook_from_module
+ from accelerate.utils import compute_module_sizes, get_max_memory
INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -376,6 +383,209 @@ def _get_pipeline_class(
return pipeline_cls
+def _load_empty_model(
+ library_name: str,
+ class_name: str,
+ importable_classes: List[Any],
+ pipelines: Any,
+ is_pipeline_module: bool,
+ name: str,
+ torch_dtype: Union[str, torch.dtype],
+ cached_folder: Union[str, os.PathLike],
+ **kwargs,
+):
+ # retrieve class objects.
+ class_obj, _ = get_class_obj_and_candidates(
+ library_name,
+ class_name,
+ importable_classes,
+ pipelines,
+ is_pipeline_module,
+ component_name=name,
+ cache_dir=cached_folder,
+ )
+
+ if is_transformers_available():
+ transformers_version = version.parse(version.parse(transformers.__version__).base_version)
+ else:
+ transformers_version = "N/A"
+
+ # Determine library.
+ is_transformers_model = (
+ is_transformers_available()
+ and issubclass(class_obj, PreTrainedModel)
+ and transformers_version >= version.parse("4.20.0")
+ )
+ diffusers_module = importlib.import_module(__name__.split(".")[0])
+ is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
+
+ model = None
+ config_path = cached_folder
+ user_agent = {
+ "diffusers": __version__,
+ "file_type": "model",
+ "framework": "pytorch",
+ }
+
+ if is_diffusers_model:
+ # Load config and then the model on meta.
+ config, unused_kwargs, commit_hash = class_obj.load_config(
+ os.path.join(config_path, name),
+ cache_dir=cached_folder,
+ return_unused_kwargs=True,
+ return_commit_hash=True,
+ force_download=kwargs.pop("force_download", False),
+ resume_download=kwargs.pop("resume_download", False),
+ proxies=kwargs.pop("proxies", None),
+ local_files_only=kwargs.pop("local_files_only", False),
+ token=kwargs.pop("token", None),
+ revision=kwargs.pop("revision", None),
+ subfolder=kwargs.pop("subfolder", None),
+ user_agent=user_agent,
+ )
+ with accelerate.init_empty_weights():
+ model = class_obj.from_config(config, **unused_kwargs)
+ elif is_transformers_model:
+ config_class = getattr(class_obj, "config_class", None)
+ if config_class is None:
+ raise ValueError("`config_class` cannot be None. Please double-check the model.")
+
+ config = config_class.from_pretrained(
+ cached_folder,
+ subfolder=name,
+ force_download=kwargs.pop("force_download", False),
+ resume_download=kwargs.pop("resume_download", False),
+ proxies=kwargs.pop("proxies", None),
+ local_files_only=kwargs.pop("local_files_only", False),
+ token=kwargs.pop("token", None),
+ revision=kwargs.pop("revision", None),
+ user_agent=user_agent,
+ )
+ with accelerate.init_empty_weights():
+ model = class_obj(config)
+
+ if model is not None:
+ model = model.to(dtype=torch_dtype)
+ return model
+
+
+def _assign_components_to_devices(
+ module_sizes: Dict[str, float], device_memory: Dict[str, float], device_mapping_strategy: str = "balanced"
+):
+ device_ids = list(device_memory.keys())
+ device_cycle = device_ids + device_ids[::-1]
+ device_memory = device_memory.copy()
+
+ device_id_component_mapping = {}
+ current_device_index = 0
+ for component in module_sizes:
+ device_id = device_cycle[current_device_index % len(device_cycle)]
+ component_memory = module_sizes[component]
+ curr_device_memory = device_memory[device_id]
+
+ # If the GPU doesn't fit the current component offload to the CPU.
+ if component_memory > curr_device_memory:
+ device_id_component_mapping["cpu"] = [component]
+ else:
+ if device_id not in device_id_component_mapping:
+ device_id_component_mapping[device_id] = [component]
+ else:
+ device_id_component_mapping[device_id].append(component)
+
+ # Update the device memory.
+ device_memory[device_id] -= component_memory
+ current_device_index += 1
+
+ return device_id_component_mapping
+
+
+def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
+ # To avoid circular import problem.
+ from diffusers import pipelines
+
+ torch_dtype = kwargs.get("torch_dtype", torch.float32)
+
+ # Load each module in the pipeline on a meta device so that we can derive the device map.
+ init_empty_modules = {}
+ for name, (library_name, class_name) in init_dict.items():
+ if class_name.startswith("Flax"):
+ raise ValueError("Flax pipelines are not supported with `device_map`.")
+
+ # Define all importable classes
+ is_pipeline_module = hasattr(pipelines, library_name)
+ importable_classes = ALL_IMPORTABLE_CLASSES
+ loaded_sub_model = None
+
+ # Use passed sub model or load class_name from library_name
+ if name in passed_class_obj:
+ # if the model is in a pipeline module, then we load it from the pipeline
+ # check that passed_class_obj has correct parent class
+ maybe_raise_or_warn(
+ library_name,
+ library,
+ class_name,
+ importable_classes,
+ passed_class_obj,
+ name,
+ is_pipeline_module,
+ )
+ with accelerate.init_empty_weights():
+ loaded_sub_model = passed_class_obj[name]
+
+ else:
+ loaded_sub_model = _load_empty_model(
+ library_name=library_name,
+ class_name=class_name,
+ importable_classes=importable_classes,
+ pipelines=pipelines,
+ is_pipeline_module=is_pipeline_module,
+ pipeline_class=pipeline_class,
+ name=name,
+ torch_dtype=torch_dtype,
+ cached_folder=kwargs.get("cached_folder", None),
+ force_download=kwargs.get("force_download", None),
+ resume_download=kwargs.get("resume_download", None),
+ proxies=kwargs.get("proxies", None),
+ local_files_only=kwargs.get("local_files_only", None),
+ token=kwargs.get("token", None),
+ revision=kwargs.get("revision", None),
+ )
+
+ if loaded_sub_model is not None:
+ init_empty_modules[name] = loaded_sub_model
+
+ # determine device map
+ # Obtain a sorted dictionary for mapping the model-level components
+ # to their sizes.
+ module_sizes = {
+ module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
+ for module_name, module in init_empty_modules.items()
+ if isinstance(module, torch.nn.Module)
+ }
+ module_sizes = dict(sorted(module_sizes.items(), key=lambda item: item[1], reverse=True))
+
+ # Obtain maximum memory available per device (GPUs only).
+ max_memory = get_max_memory(max_memory)
+ max_memory = dict(sorted(max_memory.items(), key=lambda item: item[1], reverse=True))
+ max_memory = {k: v for k, v in max_memory.items() if k != "cpu"}
+
+ # Obtain a dictionary mapping the model-level components to the available
+ # devices based on the maximum memory and the model sizes.
+ final_device_map = None
+ if len(max_memory) > 0:
+ device_id_component_mapping = _assign_components_to_devices(
+ module_sizes, max_memory, device_mapping_strategy=device_map
+ )
+
+ # Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}`
+ final_device_map = {}
+ for device_id, components in device_id_component_mapping.items():
+ for component in components:
+ final_device_map[component] = device_id
+
+ return final_device_map
+
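# A hedged sketch of the sizing step used above: components are instantiated on
# the meta device (no memory allocated) and ranked largest-first by
# `compute_module_sizes(module)[""]`. The toy modules below stand in for real
# pipeline components.
import torch
from accelerate import init_empty_weights
from accelerate.utils import compute_module_sizes

with init_empty_weights():
    toy_unet = torch.nn.Sequential(torch.nn.Linear(320, 1280), torch.nn.Linear(1280, 320))
    toy_text_encoder = torch.nn.Linear(768, 768)

module_sizes = {
    "unet": compute_module_sizes(toy_unet, dtype=torch.float16)[""],
    "text_encoder": compute_module_sizes(toy_text_encoder, dtype=torch.float16)[""],
}
module_sizes = dict(sorted(module_sizes.items(), key=lambda item: item[1], reverse=True))
print(module_sizes)  # sizes in bytes, e.g. {"unet": ..., "text_encoder": ...}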
+
def load_sub_model(
library_name: str,
class_name: str,
@@ -493,6 +703,22 @@ def load_sub_model(
# else load from the root directory
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
+ if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
+ # remove hooks
+ remove_hook_from_module(loaded_sub_model, recurse=True)
+ needs_offloading_to_cpu = device_map[""] == "cpu"
+
+ if needs_offloading_to_cpu:
+ dispatch_model(
+ loaded_sub_model,
+ state_dict=loaded_sub_model.state_dict(),
+ device_map=device_map,
+ force_hooks=True,
+ main_device=0,
+ )
+ else:
+ dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
+
return loaded_sub_model
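# A hedged sketch of the per-component dispatch above: each sub-model receives a
# one-entry device map of the form {"": <device>}, so `dispatch_model` places the
# whole module on that device after any stale accelerate hooks are removed. The
# toy module stands in for a pipeline component.
import torch
from accelerate import dispatch_model
from accelerate.hooks import remove_hook_from_module

component = torch.nn.Linear(8, 8)
remove_hook_from_module(component, recurse=True)  # a no-op here, but mirrors the cleanup above

if torch.cuda.is_available():
    dispatch_model(component, device_map={"": 0}, force_hooks=True)
    print(next(component.parameters()).device)  # cuda:0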
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index a98d736aa557..68433332546b 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -73,6 +73,7 @@
LOADABLE_CLASSES,
_fetch_class_library_tuple,
_get_custom_pipeline_class,
+ _get_final_device_map,
_get_pipeline_class,
_unwrap_model,
is_safetensors_compatible,
@@ -91,6 +92,8 @@
for library in LOADABLE_CLASSES:
LIBRARIES.append(library)
+SUPPORTED_DEVICE_MAP = ["balanced"]
+
logger = logging.get_logger(__name__)
@@ -141,6 +144,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
config_name = "model_index.json"
model_cpu_offload_seq = None
+ hf_device_map = None
_optional_components = []
_exclude_from_cpu_offload = []
_load_connected_pipes = False
@@ -389,6 +393,12 @@ def module_is_offloaded(module):
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
)
+ is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
+ if is_pipeline_device_mapped:
+ raise ValueError(
+ "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
+ )
+
# Display a warning in this case (the operation succeeds but the benefits are lost)
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
@@ -538,7 +548,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
allowed by Git.
custom_revision (`str`, *optional*):
The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
- `revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers version.
+ `revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers
+ version.
mirror (`str`, *optional*):
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
@@ -641,18 +652,35 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
" install accelerate\n```\n."
)
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `device_map=None`."
)
- if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ if device_map is not None and not is_accelerate_available():
raise NotImplementedError(
- "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
- " `low_cpu_mem_usage=False`."
+ "Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
+ )
+
+ if device_map is not None and not isinstance(device_map, str):
+ raise ValueError("`device_map` must be a string.")
+
+ if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
+ raise NotImplementedError(
+ f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
)
+ if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
+ if is_accelerate_version("<", "0.28.0"):
+ raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")
+
if low_cpu_mem_usage is False and device_map is not None:
raise ValueError(
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
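# A hedged sketch of the checks above: only the strategy strings listed in
# SUPPORTED_DEVICE_MAP are accepted, so dictionary maps and other strategies are
# rejected early (the checkpoint name is illustrative).
from diffusers import DiffusionPipeline

try:
    DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", device_map={"unet": 0})
except ValueError as err:
    print(err)  # `device_map` must be a string.

try:
    DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", device_map="auto")
except NotImplementedError as err:
    print(err)  # auto not supported. Supported strategies are: balanced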
@@ -728,6 +756,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision=custom_revision,
)
+ if device_map is not None and pipeline_class._load_connected_pipes:
+ raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")
+
# DEPRECATED: To be removed in 1.0.0
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
version.parse(config_dict["_diffusers_version"]).base_version
@@ -794,17 +825,45 @@ def load_module(name, value):
# import it here to avoid circular import
from diffusers import pipelines
- # 6. Load each module in the pipeline
+ # 6. device map delegation
+ final_device_map = None
+ if device_map is not None:
+ final_device_map = _get_final_device_map(
+ device_map=device_map,
+ pipeline_class=pipeline_class,
+ passed_class_obj=passed_class_obj,
+ init_dict=init_dict,
+ library=library,
+ max_memory=max_memory,
+ torch_dtype=torch_dtype,
+ cached_folder=cached_folder,
+ force_download=force_download,
+ resume_download=resume_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ )
+
+ # 7. Load each module in the pipeline
+ current_device_map = None
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
- # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
+ if final_device_map is not None and len(final_device_map) > 0:
+ component_device = final_device_map.get(name, None)
+ if component_device is not None:
+ current_device_map = {"": component_device}
+ else:
+ current_device_map = None
+
+ # 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
- # 6.2 Define all importable classes
+ # 7.2 Define all importable classes
is_pipeline_module = hasattr(pipelines, library_name)
importable_classes = ALL_IMPORTABLE_CLASSES
loaded_sub_model = None
- # 6.3 Use passed sub model or load class_name from library_name
+ # 7.3 Use passed sub model or load class_name from library_name
if name in passed_class_obj:
# if the model is in a pipeline module, then we load it from the pipeline
# check that passed_class_obj has correct parent class
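# A small sketch of the `current_device_map` derivation in the loop above: a
# pipeline-level map such as {"unet": 0, "text_encoder": 1, "vae": 1} is turned
# into a one-entry map per component, and components missing from the map keep
# `device_map=None` (values are illustrative).
final_device_map = {"unet": 0, "text_encoder": 1, "vae": 1}

for name in ["unet", "text_encoder", "vae", "scheduler"]:
    component_device = final_device_map.get(name, None)
    current_device_map = {"": component_device} if component_device is not None else None
    print(name, current_device_map)
# unet {'': 0}
# text_encoder {'': 1}
# vae {'': 1}
# scheduler None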
@@ -825,7 +884,7 @@ def load_module(name, value):
torch_dtype=torch_dtype,
provider=provider,
sess_options=sess_options,
- device_map=device_map,
+ device_map=current_device_map,
max_memory=max_memory,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
@@ -892,7 +951,7 @@ def get_connected_passed_kwargs(prefix):
{"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
)
- # 7. Potentially add passed objects if expected
+ # 8. Potentially add passed objects if expected
missing_modules = set(expected_modules) - set(init_kwargs.keys())
passed_modules = list(passed_class_obj.keys())
optional_modules = pipeline_class._optional_components
@@ -905,11 +964,13 @@ def get_connected_passed_kwargs(prefix):
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)
- # 8. Instantiate the pipeline
+ # 10. Instantiate the pipeline
model = pipeline_class(**init_kwargs)
- # 9. Save where the model was instantiated from
+ # 11. Save where the model was instantiated from
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+ if device_map is not None:
+ setattr(model, "hf_device_map", final_device_map)
return model
@property
@@ -962,6 +1023,12 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""
+ is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
+ if is_pipeline_device_mapped:
+ raise ValueError(
+                "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload()` isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
+ )
+
if self.model_cpu_offload_seq is None:
raise ValueError(
"Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
@@ -997,6 +1064,7 @@ def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[t
all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
+ self._all_hooks = []
hook = None
for model_str in self.model_cpu_offload_seq.split("->"):
model = all_model_components.pop(model_str, None)
@@ -1054,6 +1122,12 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
self.remove_all_hooks()
+ is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
+ if is_pipeline_device_mapped:
+ raise ValueError(
+                "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload()` isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
+ )
+
torch_device = torch.device(device)
device_index = torch_device.index
@@ -1088,6 +1162,19 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
offload_buffers = len(model._parameters) > 0
cpu_offload(model, device, offload_buffers=offload_buffers)
+ def reset_device_map(self):
+ r"""
+ Resets the device maps (if any) to None.
+ """
+ if self.hf_device_map is None:
+ return
+ else:
+ self.remove_all_hooks()
+ for name, component in self.components.items():
+ if isinstance(component, torch.nn.Module):
+ component.to("cpu")
+ self.hf_device_map = None
+
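# A hedged usage sketch of `reset_device_map`: after resetting, explicit device
# placement and CPU offloading are allowed again (the checkpoint name is
# illustrative; requires at least one CUDA GPU).
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
)
print(pipe.hf_device_map)  # e.g. {"unet": 0, "text_encoder": 0, "vae": 0, ...}

pipe.reset_device_map()          # hooks removed, every component moved back to CPU
print(pipe.hf_device_map)        # None
pipe.enable_model_cpu_offload()  # no longer blocked by the device map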
@classmethod
@validate_hf_hub_args
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
@@ -1669,7 +1756,8 @@ def set_attention_slice(self, slice_size: Optional[int]):
@classmethod
def from_pipe(cls, pipeline, **kwargs):
r"""
- Create a new pipeline from a given pipeline. This method is useful to create a new pipeline from the existing pipeline components without reallocating additional memory.
+ Create a new pipeline from a given pipeline. This method is useful to create a new pipeline from the existing
+ pipeline components without reallocating additional memory.
Arguments:
pipeline (`DiffusionPipeline`):
@@ -1728,7 +1816,7 @@ def from_pipe(cls, pipeline, **kwargs):
):
original_class_obj[name] = component
else:
- logger.warn(
+ logger.warning(
f"component {name} is not switched over to new pipeline because type does not match the expected."
f" {name} is {type(component)} while the new pipeline expect {component_types[name]}."
f" please pass the component of the correct type to the new pipeline. `from_pipe(..., {name}={name})`"
@@ -1851,8 +1939,8 @@ def disable_freeu(self):
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
- key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index e7213a38bcad..11f7516fac7b 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -186,8 +186,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -653,7 +653,12 @@ def _clean_caption(self, caption):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
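# A short sketch of why the explicit `int(...)` cast above matters: `height` and
# `width` can arrive as floats (for example after being rescaled), and floor
# division of a float yields a float, which is not a valid tensor dimension.
vae_scale_factor = 8
height = 512.0

print(height // vae_scale_factor)       # 64.0 -> torch.randn((1, 4, 64.0, 64.0)) would raise a TypeError
print(int(height) // vae_scale_factor)  # 64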
diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
index 96873423faeb..fe83a860aeac 100644
--- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
+++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
@@ -191,7 +191,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
index cb388edec973..9db92cbdd181 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
@@ -334,8 +334,8 @@ def __call__(
argument.
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
- weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input
- argument.
+ weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt`
+ input argument.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
index ecc92bbb8819..d27e727231c9 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
@@ -31,7 +31,10 @@
```py
>>> import torch
>>> from diffusers import StableCascadeCombinedPipeline
- >>> pipe = StableCascadeCombinedPipeline.from_pretrained("stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16)
+
+ >>> pipe = StableCascadeCombinedPipeline.from_pretrained(
+ ... "stabilityai/stable-cascade", variant="bf16", torch_dtype=torch.bfloat16
+ ... )
>>> pipe.enable_model_cpu_offload()
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> images = pipe(prompt=prompt)
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
index df2581a3ebeb..ce17294a257d 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
@@ -80,7 +80,8 @@ class StableCascadePriorPipeline(DiffusionPipeline):
prior ([`StableCascadeUNet`]):
The Stable Cascade prior to approximate the image embedding from the text and/or image embedding.
text_encoder ([`CLIPTextModelWithProjection`]):
- Frozen text-encoder ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
+ Frozen text-encoder
+ ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
feature_extractor ([`~transformers.CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
image_encoder ([`CLIPVisionModelWithProjection`]):
@@ -420,11 +421,11 @@ def __call__(
argument.
negative_prompt_embeds_pooled (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
- weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt` input
- argument.
+ weighting. If not provided, negative_prompt_embeds_pooled will be generated from `negative_prompt`
+ input argument.
image_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated image embeddings. Can be used to easily tweak image inputs, *e.g.* prompt weighting.
- If not provided, image embeddings will be generated from `image` input argument if existing.
+ Pre-generated image embeddings. Can be used to easily tweak image inputs, *e.g.* prompt weighting. If
+ not provided, image embeddings will be generated from `image` input argument if existing.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -452,9 +453,9 @@ def __call__(
Examples:
Returns:
- [`StableCascadePriorPipelineOutput`] or `tuple` [`StableCascadePriorPipelineOutput`] if
- `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
- generated image embeddings.
+ [`StableCascadePriorPipelineOutput`] or `tuple` [`StableCascadePriorPipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated image
+ embeddings.
"""
# 0. Define commonly used variables
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 97bf139b6a74..f04a21ef4857 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -1153,6 +1153,8 @@ def download_from_original_stable_diffusion_ckpt(
controlnet: Optional[bool] = None,
adapter: Optional[bool] = None,
load_safety_checker: bool = True,
+ safety_checker: Optional[StableDiffusionSafetyChecker] = None,
+ feature_extractor: Optional[AutoFeatureExtractor] = None,
pipeline_class: DiffusionPipeline = None,
local_files_only=False,
vae_path=None,
@@ -1205,6 +1207,12 @@ def download_from_original_stable_diffusion_ckpt(
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
load_safety_checker (`bool`, *optional*, defaults to `True`):
Whether to load the safety checker or not. Defaults to `True`.
+ safety_checker (`StableDiffusionSafetyChecker`, *optional*, defaults to `None`):
+ Safety checker to use. If this parameter is `None`, the function will load a new instance of
+            [`StableDiffusionSafetyChecker`] by itself, if needed.
+ feature_extractor (`AutoFeatureExtractor`, *optional*, defaults to `None`):
+ Feature extractor to use. If this parameter is `None`, the function will load a new instance of
+            [`AutoFeatureExtractor`] by itself, if needed.
pipeline_class (`str`, *optional*, defaults to `None`):
The pipeline class to use. Pass `None` to determine automatically.
local_files_only (`bool`, *optional*, defaults to `False`):
@@ -1530,8 +1538,8 @@ def download_from_original_stable_diffusion_ckpt(
unet=unet,
scheduler=scheduler,
controlnet=controlnet,
- safety_checker=None,
- feature_extractor=None,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
)
if hasattr(pipe, "requires_safety_checker"):
pipe.requires_safety_checker = False
@@ -1551,8 +1559,8 @@ def download_from_original_stable_diffusion_ckpt(
unet=unet,
scheduler=scheduler,
low_res_scheduler=low_res_scheduler,
- safety_checker=None,
- feature_extractor=None,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
)
else:
@@ -1562,8 +1570,8 @@ def download_from_original_stable_diffusion_ckpt(
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
- safety_checker=None,
- feature_extractor=None,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
)
if hasattr(pipe, "requires_safety_checker"):
pipe.requires_safety_checker = False
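# A hedged sketch of the new pass-through arguments: a caller can hand in an
# already-loaded safety checker and feature extractor instead of letting the
# function instantiate fresh ones (the local checkpoint path is illustrative).
from transformers import AutoFeatureExtractor

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")

pipe = download_from_original_stable_diffusion_ckpt(
    "./v1-5-pruned-emaonly.ckpt",    # illustrative local checkpoint
    load_safety_checker=False,       # skip the default loader so the objects below are used
    safety_checker=safety_checker,
    feature_extractor=feature_extractor,
)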
@@ -1684,9 +1692,6 @@ def download_from_original_stable_diffusion_ckpt(
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
)
- else:
- safety_checker = None
- feature_extractor = None
if controlnet:
pipe = pipeline_class(
@@ -1838,6 +1843,8 @@ def download_controlnet_from_original_ckpt(
while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
+ with open(original_config_file, "r") as f:
+ original_config_file = f.read()
original_config = yaml.safe_load(original_config_file)
if num_in_channels is not None:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 0262105f43a7..d8da3b0c4bee 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -89,8 +89,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -657,7 +657,12 @@ def check_inputs(
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -806,10 +811,10 @@ def __call__(
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
index 72b438cd3325..1f822971568f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -548,8 +548,15 @@ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_gui
pixel_values = pixel_values.to(device=device)
# The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16.
# So we use `torch.autocast` here for half precision inference.
- context_manger = torch.autocast("cuda", dtype=dtype) if device.type == "cuda" else contextlib.nullcontext()
- with context_manger:
+ if torch.backends.mps.is_available():
+ autocast_ctx = contextlib.nullcontext()
+ logger.warning(
+ "The DPT-Hybrid model uses batch-norm layers which are not compatible with fp16, but autocast is not yet supported on MPS."
+ )
+ else:
+ autocast_ctx = torch.autocast(device.type, dtype=dtype)
+
+ with autocast_ctx:
depth_map = self.depth_estimator(pixel_values).predicted_depth
else:
depth_map = depth_map.to(device=device, dtype=dtype)
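# A hedged sketch of the context-manager selection above: on MPS, autocast is
# skipped via a null context; elsewhere, autocast for the module's device type is
# entered around the forward pass.
import contextlib
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.bfloat16

if torch.backends.mps.is_available():
    autocast_ctx = contextlib.nullcontext()
else:
    autocast_ctx = torch.autocast(device.type, dtype=dtype)

model = torch.nn.Linear(4, 4).to(device)
with autocast_ctx:
    out = model(torch.randn(1, 4, device=device))
print(out.dtype)  # half precision under autocast, float32 when autocast is skipped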
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
index afd872904750..c300c7a2f3f4 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -224,7 +224,12 @@ def check_inputs(self, image, height, width, callback_steps):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 540eed6ebd56..1b31c099b177 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -125,8 +125,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -897,10 +897,10 @@ def __call__(
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index f5bef3f5f14e..636ef5a562ff 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -189,8 +189,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -796,7 +796,12 @@ def prepare_latents(
return_noise=False,
return_image_latents=False,
):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -1023,11 +1028,12 @@ def __call__(
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
padding_mask_crop (`int`, *optional*, defaults to `None`):
- The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
- `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
- contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
- the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
- and contain information irrelevant for inpainting, such as background.
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+                with the same aspect ratio as the image that contains all masked areas, and then expand that area based
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
+                the image is large and contains information irrelevant for inpainting, such as background.
strength (`float`, *optional*, defaults to 1.0):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
@@ -1067,10 +1073,10 @@ def __call__(
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index 01c2eaea062d..de2767e23952 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -730,7 +730,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index 32b43a8e7f7f..02ddc65c7111 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -876,7 +876,12 @@ def __call__(
# 11. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
latents = self.prepare_latents(
shape=shape,
dtype=prompt_embeds.dtype,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
index 9b85d9e6b1a4..fe19b4de3127 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
@@ -543,7 +543,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
index 8e43676494a5..2adcb0a8c0a1 100644
--- a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
@@ -581,7 +581,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
index 206c3436bb3d..e89d7f77e8b1 100644
--- a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
+++ b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
@@ -740,7 +740,12 @@ def get_inverse_timesteps(self, num_inference_steps, strength, device):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
index 6273128be2db..94043b7285c9 100644
--- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
+++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
@@ -476,7 +476,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py
index 3570cdce99bc..c20e940b4db6 100644
--- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py
+++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py
@@ -500,7 +500,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
index bc565c938a30..e2096be7e894 100755
--- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -441,7 +441,12 @@ def check_inputs(
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
index ed46a1e36b60..3cfda4064d13 100644
--- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
@@ -497,7 +497,12 @@ def check_inputs(
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
index 170551312782..880d6dcf401a 100644
--- a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
+++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
@@ -90,8 +90,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -627,7 +627,12 @@ def check_inputs(
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -773,10 +778,10 @@ def __call__(
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
index cd5189b85e68..514c8643ba36 100644
--- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
@@ -90,8 +90,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -638,7 +638,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -694,9 +699,9 @@ def get_views(
circular_padding: bool = False,
) -> List[Tuple[int, int, int, int]]:
"""
- Generates a list of views based on the given parameters.
- Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113).
- If panorama's height/width < window_size, num_blocks of height/width should return 1.
+ Generates a list of views based on the given parameters. Here, we define the mappings F_i (see Eq. 7 in the
+ MultiDiffusion paper https://arxiv.org/abs/2302.08113). If panorama's height/width < window_size, num_blocks of
+ height/width should return 1.
Args:
panorama_height (int): The height of the panorama.
@@ -706,8 +711,8 @@ def get_views(
circular_padding (bool, optional): Whether to apply circular padding. Defaults to False.
Returns:
- List[Tuple[int, int, int, int]]: A list of tuples representing the views. Each tuple contains
- four integers representing the start and end coordinates of the window in the panorama.
+ List[Tuple[int, int, int, int]]: A list of tuples representing the views. Each tuple contains four integers
+ representing the start and end coordinates of the window in the panorama.
"""
panorama_height /= 8
@@ -800,8 +805,8 @@ def __call__(
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
- The timesteps at which to generate the images. If not specified, then the default
- timestep spacing strategy of the scheduler is used.
+ The timesteps at which to generate the images. If not specified, then the default timestep spacing
+ strategy of the scheduler is used.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -832,10 +837,10 @@ def __call__(
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
index ae74e09678e3..63b8c6108ac4 100644
--- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
+++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
@@ -416,7 +416,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
index 2e7a1fa41b58..cb29ce386f32 100644
--- a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
@@ -535,7 +535,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -619,8 +624,8 @@ def __call__(
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. If not
- provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. If not provided, embeddings are computed from the
+ `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 8d1646e4d887..5af29b5719ce 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -117,8 +117,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -685,7 +685,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -919,10 +924,10 @@ def __call__(
input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
+ of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It
+ should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1199,10 +1204,6 @@ def __call__(
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
- else:
- raise ValueError(
- "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
- )
if callback_on_step_end is not None:
callback_kwargs = {}
@@ -1241,10 +1242,6 @@ def __call__(
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
self.vae = self.vae.to(latents.dtype)
- else:
- raise ValueError(
- "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
- )
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
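The deleted `else: raise ValueError(...)` branches above mean that a dtype mismatch on non-MPS backends is no longer treated as fatal; only the known MPS casting bug still triggers a re-cast. A hedged sketch of the remaining guard (not the pipeline's literal code):

```python
import torch

def maybe_recast(latents: torch.Tensor, expected_dtype: torch.dtype) -> torch.Tensor:
    # Only work around the MPS bug (pytorch/pytorch#99272); leave other backends alone.
    if latents.dtype != expected_dtype and torch.backends.mps.is_available():
        latents = latents.to(expected_dtype)
    return latents
```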
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index af9da5073e06..b72b19d5c1ef 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -134,8 +134,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -665,6 +665,12 @@ def prepare_latents(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
+ latents_mean = latents_std = None
+ if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+ latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
+
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
@@ -702,7 +708,12 @@ def prepare_latents(
self.vae.to(dtype)
init_latents = init_latents.to(dtype)
- init_latents = self.vae.config.scaling_factor * init_latents
+ if latents_mean is not None and latents_std is not None:
+ latents_mean = latents_mean.to(device=self.device, dtype=dtype)
+ latents_std = latents_std.to(device=self.device, dtype=dtype)
+ init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
+ else:
+ init_latents = self.vae.config.scaling_factor * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
@@ -1067,10 +1078,10 @@ def __call__(
input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
+ of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It
+ should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1376,10 +1387,6 @@ def denoising_value_valid(dnv):
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
- else:
- raise ValueError(
- "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
- )
if callback_on_step_end is not None:
callback_kwargs = {}
@@ -1418,10 +1425,6 @@ def denoising_value_valid(dnv):
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
self.vae = self.vae.to(latents.dtype)
- else:
- raise ValueError(
- "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
- )
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
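The img2img `prepare_latents` change above standardizes the encoded image latents with `latents_mean`/`latents_std` when the VAE config provides them, instead of only multiplying by `scaling_factor`. A hedged sketch of the two paths (tensor shapes are assumptions):

```python
import torch

def scale_init_latents(init_latents, scaling_factor, latents_mean=None, latents_std=None):
    # Path taken when the VAE config defines latents_mean and latents_std.
    if latents_mean is not None and latents_std is not None:
        return (init_latents - latents_mean) * scaling_factor / latents_std
    # Classic SDXL behaviour otherwise.
    return scaling_factor * init_latents

init_latents = torch.randn(1, 4, 64, 64)
mean = torch.zeros(1, 4, 1, 1)
std = torch.ones(1, 4, 1, 1)
scaled = scale_init_latents(init_latents, scaling_factor=0.13025, latents_mean=mean, latents_std=std)
```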
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index c9a72ccda985..26c1484e9ae2 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -279,8 +279,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -880,7 +880,12 @@ def prepare_latents(
return_noise=False,
return_image_latents=False,
):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -1255,11 +1260,12 @@ def __call__(
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
and checkpoints that are not specifically fine-tuned on low resolutions.
padding_mask_crop (`int`, *optional*, defaults to `None`):
- The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to image and mask_image. If
- `padding_mask_crop` is not `None`, it will first find a rectangular region with the same aspect ration of the image and
- contains all masked area, and then expand that area based on `padding_mask_crop`. The image and mask_image will then be cropped based on
- the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large
- and contain information irrelevant for inpainting, such as background.
+ The size of the margin in the crop applied to the image and mask. If `None`, no crop is applied to
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+ with the same aspect ratio as the image that contains all masked areas, and then expand that area
+ based on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area
+ before resizing to the original image size for inpainting. This is useful when the masked area is small
+ while the image is large and contains information irrelevant for inpainting, such as background.
strength (`float`, *optional*, defaults to 0.9999):
Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
@@ -1319,10 +1325,10 @@ def __call__(
input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
+ of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It
+ should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
@@ -1726,10 +1732,6 @@ def denoising_value_valid(dnv):
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
- else:
- raise ValueError(
- "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
- )
if num_channels_unet == 4:
init_latents_proper = image_latents
@@ -1785,10 +1787,6 @@ def denoising_value_valid(dnv):
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
self.vae = self.vae.to(latents.dtype)
- else:
- raise ValueError(
- "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
- )
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
index 9aedb8667587..a2242bb099c5 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
@@ -169,6 +169,8 @@ class StableDiffusionXLInstructPix2PixPipeline(
Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
watermark output images. If not defined, it will default to True if the package is installed, otherwise no
watermarker will be used.
+ is_cosxl_edit (`bool`, *optional*, defaults to `False`):
+ When set, the image latents are scaled by the VAE scaling factor, as required for CosXL edit checkpoints.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
@@ -185,6 +187,7 @@ def __init__(
scheduler: KarrasDiffusionSchedulers,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
+ is_cosxl_edit: Optional[bool] = False,
):
super().__init__()
@@ -201,6 +204,7 @@ def __init__(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
+ self.is_cosxl_edit = is_cosxl_edit
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -483,7 +487,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -551,6 +560,9 @@ def prepare_image_latents(
if image_latents.dtype != self.vae.dtype:
image_latents = image_latents.to(dtype=self.vae.dtype)
+ if self.is_cosxl_edit:
+ image_latents = image_latents * self.vae.config.scaling_factor
+
return image_latents
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
@@ -924,10 +936,6 @@ def __call__(
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
- else:
- raise ValueError(
- "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
- )
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -950,10 +958,6 @@ def __call__(
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
self.vae = self.vae.to(latents.dtype)
- else:
- raise ValueError(
- "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
- )
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
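With the new `is_cosxl_edit` flag, `prepare_image_latents` additionally multiplies the conditioning image latents by the VAE scaling factor. A hedged usage sketch; the checkpoint id below is a placeholder, not a verified repository:

```python
import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline

pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
    "some-org/cosxl-edit",  # placeholder checkpoint id
    torch_dtype=torch.float16,
    is_cosxl_edit=True,     # scale image latents for CosXL edit checkpoints
)
```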
diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
index 1342fe429145..070183b92409 100644
--- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
+++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
@@ -37,10 +37,14 @@
>>> from diffusers import StableVideoDiffusionPipeline
>>> from diffusers.utils import load_image, export_to_video
- >>> pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
+ >>> pipe = StableVideoDiffusionPipeline.from_pretrained(
+ ... "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
+ ... )
>>> pipe.to("cuda")
- >>> image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg")
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg"
+ ... )
>>> image = image.resize((1024, 576))
>>> frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
@@ -86,8 +90,8 @@ class StableVideoDiffusionPipelineOutput(BaseOutput):
Args:
frames (`[List[List[PIL.Image.Image]]`, `np.ndarray`, `torch.FloatTensor`]):
- List of denoised PIL images of length `batch_size` or numpy array or torch tensor
- of shape `(batch_size, num_frames, height, width, num_channels)`.
+ List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
+ num_frames, height, width, num_channels)`.
"""
frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.FloatTensor]
@@ -104,7 +108,8 @@ class StableVideoDiffusionPipeline(DiffusionPipeline):
vae ([`AutoencoderKLTemporalDecoder`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
- Frozen CLIP image-encoder ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
+ Frozen CLIP image-encoder
+ ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
unet ([`UNetSpatioTemporalConditionModel`]):
A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
scheduler ([`EulerDiscreteScheduler`]):
@@ -357,14 +362,15 @@ def __call__(
Args:
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
- Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0, 1]`.
+ Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
+ 1]`.
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
The width in pixels of the generated image.
num_frames (`int`, *optional*):
- The number of video frames to generate. Defaults to `self.unet.config.num_frames`
- (14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
+ The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for
+ `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
num_inference_steps (`int`, *optional*, defaults to 25):
The number of denoising steps. More denoising steps usually lead to a higher quality video at the
expense of slower inference. This parameter is modulated by `strength`.
@@ -373,16 +379,18 @@ def __call__(
max_guidance_scale (`float`, *optional*, defaults to 3.0):
The maximum guidance scale. Used for the classifier free guidance with last frame.
fps (`int`, *optional*, defaults to 7):
- Frames per second. The rate at which the generated images shall be exported to a video after generation.
- Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
+ Frames per second. The rate at which the generated images shall be exported to a video after
+ generation. Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
motion_bucket_id (`int`, *optional*, defaults to 127):
Used for conditioning the amount of motion for the generation. The higher the number the more motion
will be in the video.
noise_aug_strength (`float`, *optional*, defaults to 0.02):
- The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
+ The amount of noise added to the init image; the higher it is, the less the video will look like the
+ init image. Increase it for more motion.
decode_chunk_size (`int`, *optional*):
- The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the expense of more memory usage. By default, the decoder decodes all frames at once for maximal
- quality. For lower memory usage, reduce `decode_chunk_size`.
+ The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
+ expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
+ For lower memory usage, reduce `decode_chunk_size`.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -398,7 +406,8 @@ def __call__(
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments:
`callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
- `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
@@ -411,8 +420,9 @@ def __call__(
Returns:
[`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
- If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
- otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.FloatTensor`) is returned.
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is
+ returned, otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.FloatTensor`)
+ is returned.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
@@ -627,7 +637,7 @@ def _filter2d(input, kernel):
height, width = tmp_kernel.shape[-2:]
- padding_shape: list[int] = _compute_padding([height, width])
+ padding_shape: List[int] = _compute_padding([height, width])
input = torch.nn.functional.pad(input, padding_shape, mode="reflect")
# kernel and input tensor reshape to align element-wise or batch-wise params
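The `padding_shape: list[int]` to `List[int]` change above keeps the annotation valid on Python 3.8, where subscripting the built-in `list` in an evaluated annotation raises `TypeError`. A small illustration of the same pattern around `torch.nn.functional.pad` (values are illustrative):

```python
from typing import List

import torch
import torch.nn.functional as F

# `List[int]` from `typing` works on Python 3.8, while the built-in generic
# `list[int]` is only valid in evaluated annotations from Python 3.9 onwards.
padding_shape: List[int] = [1, 1, 1, 1]  # (left, right, top, bottom)

x = torch.randn(1, 1, 4, 4)
padded = F.pad(x, padding_shape, mode="reflect")
print(padded.shape)  # torch.Size([1, 1, 6, 6])
```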
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
index 10f8dc66f79d..fd5b1bb454b7 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -134,8 +134,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -569,7 +569,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index 59d4022923eb..fdf938b248e6 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -150,8 +150,8 @@ def retrieve_timesteps(
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
- `timesteps` must be `None`.
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
@@ -700,7 +700,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -943,10 +948,10 @@ def __call__(
input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
- Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
- if `do_classifier_free_guidance` is set to `True`.
- If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
+ of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It
+ should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_output.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_output.py
index c155386cf173..2dae5b4ead69 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_output.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_output.py
@@ -17,7 +17,8 @@ class TextToVideoSDPipelineOutput(BaseOutput):
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
- List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing denoised
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised
PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`
"""
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
index 5ac211eef80f..dddd65079625 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
@@ -495,7 +495,12 @@ def check_inputs(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
index 07d7e92e11d9..2dbd928b916e 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
@@ -471,7 +471,12 @@ def _get_add_time_ids(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
- shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
index c074b9916301..6579e272a3bf 100644
--- a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
+++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
@@ -752,7 +752,8 @@ def forward(
cross_attention_kwargs (*optional*):
Keyword arguments to supply to the cross attention layers, if used.
return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+ Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
hidden_states_is_embedding (`bool`, *optional*, defaults to `False`):
Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will
ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the
diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py
index dc3d8455bdfe..23c71a61452a 100644
--- a/src/diffusers/schedulers/scheduling_ddim_flax.py
+++ b/src/diffusers/schedulers/scheduling_ddim_flax.py
@@ -85,7 +85,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
trained_betas (`jnp.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
clip_sample (`bool`, default `True`):
- option to clip predicted sample between for numerical stability. The clip range is determined by `clip_sample_range`.
+ option to clip the predicted sample for numerical stability. The clip range is determined by
+ `clip_sample_range`.
clip_sample_range (`float`, default `1.0`):
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
set_alpha_to_one (`bool`, default `True`):
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 5b452bddba70..7e0939e0d927 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -166,8 +166,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
`lambda(t)`.
final_sigmas_type (`str`, defaults to `"zero"`):
- The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
- is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
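For reference, the `final_sigmas_type` option documented above is a constructor argument; a brief, hedged illustration of the two documented values:

```python
from diffusers import DPMSolverMultistepScheduler

# `"zero"` ends the sampling trajectory at sigma = 0 (the documented default);
# `"sigma_min"` reuses the smallest sigma of the training schedule instead.
sched_zero = DPMSolverMultistepScheduler(final_sigmas_type="zero")
sched_sigma_min = DPMSolverMultistepScheduler(final_sigmas_type="sigma_min")
```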
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 215d94a7863f..d7a073c2383e 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -108,11 +108,11 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`.
algorithm_type (`str`, defaults to `dpmsolver++`):
- Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The
- `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
- paper, and the `dpmsolver++` type implements the algorithms in the
- [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
- `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
+ Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the
+ algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type
+ implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is
+ recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in
+ Stable Diffusion.
solver_type (`str`, defaults to `midpoint`):
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
@@ -123,8 +123,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
- The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
- is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
index 9422d57cff89..26a41d7335c5 100644
--- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
@@ -62,10 +62,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`.
algorithm_type (`str`, defaults to `dpmsolver++`):
- Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. The
- `dpmsolver++` type implements the algorithms in the
- [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
- `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
+ Algorithm type for the solver; can be `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver++` type implements
+ the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to
+ use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
solver_type (`str`, defaults to `midpoint`):
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
@@ -77,8 +76,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference
steps, but sometimes may result in blurring.
final_sigmas_type (`str`, defaults to `"zero"`):
- The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
- is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
"""
_compatibles = []
diff --git a/src/diffusers/schedulers/scheduling_edm_euler.py b/src/diffusers/schedulers/scheduling_edm_euler.py
index bad6aeff8b62..f6a09ca1ee16 100644
--- a/src/diffusers/schedulers/scheduling_edm_euler.py
+++ b/src/diffusers/schedulers/scheduling_edm_euler.py
@@ -278,8 +278,7 @@ def step(
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
- Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or
- tuple.
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or tuple.
Returns:
[`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or `tuple`:
diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py
index a80cc66a393d..b8d95c609bf1 100644
--- a/src/diffusers/schedulers/scheduling_sasolver.py
+++ b/src/diffusers/schedulers/scheduling_sasolver.py
@@ -92,19 +92,20 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
predictor_order (`int`, defaults to 2):
- The predictor order which can be `1` or `2` or `3` or '4'. It is recommended to use `predictor_order=2` for guided
- sampling, and `predictor_order=3` for unconditional sampling.
+ The predictor order, which can be `1`, `2`, `3`, or `4`. It is recommended to use `predictor_order=2` for
+ guided sampling, and `predictor_order=3` for unconditional sampling.
corrector_order (`int`, defaults to 2):
- The corrector order which can be `1` or `2` or `3` or '4'. It is recommended to use `corrector_order=2` for guided
- sampling, and `corrector_order=3` for unconditional sampling.
+ The corrector order, which can be `1`, `2`, `3`, or `4`. It is recommended to use `corrector_order=2` for
+ guided sampling, and `corrector_order=3` for unconditional sampling.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://imagen.research.google/video/paper.pdf) paper).
tau_func (`Callable`, *optional*):
- Stochasticity during the sampling. Default in init is `lambda t: 1 if t >= 200 and t <= 800 else 0`. SA-Solver
- will sample from vanilla diffusion ODE if tau_func is set to `lambda t: 0`. SA-Solver will sample from vanilla
- diffusion SDE if tau_func is set to `lambda t: 1`. For more details, please check https://arxiv.org/abs/2309.05019
+ Stochasticity during the sampling. Default in init is `lambda t: 1 if t >= 200 and t <= 800 else 0`.
+ SA-Solver will sample from vanilla diffusion ODE if tau_func is set to `lambda t: 0`. SA-Solver will sample
+ from vanilla diffusion SDE if tau_func is set to `lambda t: 1`. For more details, please check
+ https://arxiv.org/abs/2309.05019
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
as Stable Diffusion.
@@ -114,8 +115,8 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
`algorithm_type="dpmsolver++"`.
algorithm_type (`str`, defaults to `data_prediction`):
- Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use `data_prediction`
- with `solver_order=2` for guided sampling like in Stable Diffusion.
+ Algorithm type for the solver; can be `data_prediction` or `noise_prediction`. It is recommended to use
+ `data_prediction` with `solver_order=2` for guided sampling like in Stable Diffusion.
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Default = True.
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
@@ -402,14 +403,14 @@ def convert_model_output(
**kwargs,
) -> torch.FloatTensor:
"""
- Convert the model output to the corresponding type the data_prediction/noise_prediction algorithm needs. Noise_prediction is
- designed to discretize an integral of the noise prediction model, and data_prediction is designed to discretize an
- integral of the data prediction model.
+ Convert the model output to the corresponding type the data_prediction/noise_prediction algorithm needs.
+ Noise_prediction is designed to discretize an integral of the noise prediction model, and data_prediction is
+ designed to discretize an integral of the data prediction model.
- The algorithm and model type are decoupled. You can use either data_prediction or noise_prediction for both noise
- prediction and data prediction models.
+ The algorithm and model type are decoupled. You can use either data_prediction or noise_prediction for both
+ noise prediction and data prediction models.
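The `tau_func` described in the docstring above controls per-step stochasticity and is passed to the scheduler at construction time. A hedged sketch of the documented choices:

```python
from diffusers import SASolverScheduler

def ode_tau(t):
    return 0  # tau == 0 everywhere: vanilla diffusion ODE sampling

def default_tau(t):
    return 1 if 200 <= t <= 800 else 0  # the documented default behaviour

# Deterministic ODE-style sampling; pass `default_tau` (or omit the argument)
# for the default stochastic behaviour.
scheduler = SASolverScheduler(tau_func=ode_tau)
```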
diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py
index ee3cde5d2142..0216b7afc80a 100644
--- a/src/diffusers/schedulers/scheduling_tcd.py
+++ b/src/diffusers/schedulers/scheduling_tcd.py
@@ -132,8 +132,8 @@ def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
class TCDScheduler(SchedulerMixin, ConfigMixin):
"""
- `TCDScheduler` incorporates the `Strategic Stochastic Sampling` introduced by the paper `Trajectory Consistency Distillation`,
- extending the original Multistep Consistency Sampling to enable unrestricted trajectory traversal.
+ `TCDScheduler` incorporates the `Strategic Stochastic Sampling` introduced by the paper `Trajectory Consistency
+ Distillation`, extending the original Multistep Consistency Sampling to enable unrestricted trajectory traversal.
This code is based on the official repo of TCD(https://github.com/jabir-zheng/TCD).
@@ -543,8 +543,9 @@ def step(
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
- A stochastic parameter (referred to as `gamma` in the paper) used to control the stochasticity in every step.
- When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic sampling.
+ A stochastic parameter (referred to as `gamma` in the paper) used to control the stochasticity in every
+ step. When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic
+ sampling.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`, *optional*, defaults to `True`):
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index 70e63a64c0a8..74e97a33f1e2 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -71,6 +71,43 @@ def alpha_bar_fn(t):
return torch.tensor(betas, dtype=torch.float32)
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+ Args:
+ betas (`torch.FloatTensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
`UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models.
@@ -128,8 +165,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
final_sigmas_type (`str`, defaults to `"zero"`):
- The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma
- is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+ sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -157,6 +198,7 @@ def __init__(
timestep_spacing: str = "linspace",
steps_offset: int = 0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
+ rescale_betas_zero_snr: bool = False,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -171,8 +213,17 @@ def __init__(
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ if rescale_betas_zero_snr:
+ # Close to 0 without being 0 so first sigma is not inf
+ # FP16 smallest positive subnormal works well here
+ self.alphas_cumprod[-1] = 2**-24
+
# Currently we only support VP-type noise schedule
self.alpha_t = torch.sqrt(self.alphas_cumprod)
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
@@ -576,7 +627,7 @@ def multistep_uni_p_bh_update(
if order == 2:
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
else:
- rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
+ rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
else:
D1s = None
@@ -714,7 +765,7 @@ def multistep_uni_c_bh_update(
if order == 1:
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
else:
- rhos_c = torch.linalg.solve(R, b)
+ rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
if self.predict_x0:
x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
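The UniPC additions above port `rescale_zero_terminal_snr` from the DDIM scheduler and expose it via `rescale_betas_zero_snr`; the last `alphas_cumprod` entry is clamped to `2**-24` so the first sigma stays finite. A short usage sketch (the beta values shown are the common Stable Diffusion defaults, assumed here):

```python
from diffusers import UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    rescale_betas_zero_snr=True,
)
scheduler.set_timesteps(num_inference_steps=25)
print(scheduler.alphas_cumprod[-1])  # clamped to 2**-24 rather than reaching 0
```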
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 14947848a43f..b04006cb5ee6 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -92,6 +92,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class ControlNetXSAdapter(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class I2VGenXLUNet(metaclass=DummyObject):
_backends = ["torch"]
@@ -287,6 +302,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class UNetControlNetXSModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class UNetMotionModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index f64c15702087..8ad2f4b4876d 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -902,6 +902,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class StableDiffusionControlNetXSPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1247,6 +1262,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py
index f744c1dfb1aa..add95812122c 100644
--- a/src/diffusers/utils/dynamic_modules_utils.py
+++ b/src/diffusers/utils/dynamic_modules_utils.py
@@ -246,8 +246,8 @@ def get_cached_module_file(
- You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private
- or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
+ You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
+ [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
@@ -434,8 +434,8 @@ def get_class_from_dynamic_module(
- You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private
- or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
+ You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
+ [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index e554b42ddd31..d70ee53aaa41 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -112,7 +112,8 @@ def load_or_create_model_card(
repo_id_or_path (`str`):
The repo id (e.g., "runwayml/stable-diffusion-v1-5") or local path where to look for the model card.
token (`str`, *optional*):
- Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details.
+ Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more
+ details.
is_pipeline (`bool`):
Boolean to indicate if we're adding tag to a [`DiffusionPipeline`].
from_training: (`bool`): Boolean flag to denote if the model card is being created from a training script.
diff --git a/src/diffusers/utils/loading_utils.py b/src/diffusers/utils/loading_utils.py
index 18f6ead64c4e..aa087e981731 100644
--- a/src/diffusers/utils/loading_utils.py
+++ b/src/diffusers/utils/loading_utils.py
@@ -16,8 +16,8 @@ def load_image(
image (`str` or `PIL.Image.Image`):
The image to convert to the PIL Image format.
convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional):
- A conversion method to apply to the image after loading it.
- When set to `None` the image will be converted "RGB".
+ A conversion method to apply to the image after loading it. When set to `None`, the image will be
+ converted to "RGB".
Returns:
`PIL.Image.Image`:
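A short example of the `convert_method` hook documented above; the URL is a placeholder:

```python
from diffusers.utils import load_image

# Convert to grayscale instead of the default RGB conversion.
image = load_image(
    "https://example.com/input.png",            # placeholder URL
    convert_method=lambda img: img.convert("L"),
)
```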
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index feececc56966..8ea12e2e3b3f 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -64,9 +64,11 @@ def recurse_remove_peft_layers(model):
module_replaced = False
if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
- new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
- module.weight.device
- )
+ new_module = torch.nn.Linear(
+ module.in_features,
+ module.out_features,
+ bias=module.bias is not None,
+ ).to(module.weight.device)
new_module.weight = module.weight
if module.bias is not None:
new_module.bias = module.bias
@@ -110,6 +112,9 @@ def scale_lora_layers(model, weight):
"""
from peft.tuners.tuners_utils import BaseTunerLayer
+ if weight == 1.0:
+ return
+
for module in model.modules():
if isinstance(module, BaseTunerLayer):
module.scale_layer(weight)
@@ -129,6 +134,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
"""
from peft.tuners.tuners_utils import BaseTunerLayer
+ if weight == 1.0:
+ return
+
for module in model.modules():
if isinstance(module, BaseTunerLayer):
if weight is not None and weight != 0:
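The early returns added to `scale_lora_layers` and `unscale_lora_layers` above skip the module walk entirely when `weight == 1.0`, since scaling by 1 is a no-op. A minimal sketch of the pattern (the `BaseTunerLayer` check is simplified here):

```python
def scale_lora_layers_sketch(model, weight: float) -> None:
    if weight == 1.0:
        return  # scaling by 1.0 changes nothing; avoid touching every module
    for module in model.modules():
        if hasattr(module, "scale_layer"):  # stands in for isinstance(module, BaseTunerLayer)
            module.scale_layer(weight)
```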
diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py
index 35fc4210a908..dc303a35a8e3 100644
--- a/src/diffusers/utils/state_dict_utils.py
+++ b/src/diffusers/utils/state_dict_utils.py
@@ -253,8 +253,8 @@ def convert_unet_state_dict_to_peft(state_dict):
def convert_all_state_dict_to_peft(state_dict):
r"""
- Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer`
- for a valid `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft`
+ Attempts `convert_state_dict_to_peft` first; if it doesn't detect `lora_linear_layer` (i.e. the input is not a
+ valid `DIFFUSERS` LoRA), falls back to converting only the UNet via `convert_unet_state_dict_to_peft`.
"""
try:
peft_dict = convert_state_dict_to_peft(state_dict)
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index 4ea541dac356..c1756d6590d1 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -156,8 +156,8 @@ def get_tests_dir(append_path=None):
# https://github.com/huggingface/accelerate/pull/1964
def str_to_bool(value) -> int:
"""
- Converts a string representation of truth to `True` (1) or `False` (0).
- True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;
+ Converts a string representation of truth to `True` (1) or `False` (0). True values are `y`, `yes`, `t`, `true`,
+ `on`, and `1`; False values are `n`, `no`, `f`, `false`, `off`, and `0`.
"""
value = value.lower()
if value in ("y", "yes", "t", "true", "on", "1"):
@@ -255,6 +255,20 @@ def require_torch_accelerator(test_case):
)
+def require_torch_multi_gpu(test_case):
+ """
+ Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
+ multiple GPUs. To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests
+ -k "multi_gpu"
+ """
+ if not is_torch_available():
+ return unittest.skip("test requires PyTorch")(test_case)
+
+ import torch
+
+ return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
+
+
def require_torch_accelerator_with_fp16(test_case):
"""Decorator marking a test that requires an accelerator with support for the FP16 data type."""
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
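Typical usage of the new decorator; the test class and method names below are made up for illustration:

    import unittest
    from diffusers.utils.testing_utils import require_torch_multi_gpu

    class ExampleMultiGpuTests(unittest.TestCase):
        @require_torch_multi_gpu
        def test_multi_gpu_inference(self):
            ...  # only runs when torch.cuda.device_count() > 1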
@@ -343,6 +357,18 @@ def decorator(test_case):
return decorator
+def require_accelerate_version_greater(accelerate_version):
+ def decorator(test_case):
+ correct_accelerate_version = is_accelerate_available() and version.parse(
+ version.parse(importlib.metadata.version("accelerate")).base_version
+ ) > version.parse(accelerate_version)
+ return unittest.skipUnless(
+ correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
+ )(test_case)
+
+ return decorator
+
+
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend
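The version gate composes with the other markers; a hedged sketch (the class name and the "0.26.0" threshold are only illustrative):

    import unittest
    from diffusers.utils.testing_utils import require_accelerate_version_greater, require_torch_multi_gpu

    class ExampleAccelerateTests(unittest.TestCase):
        @require_torch_multi_gpu
        @require_accelerate_version_greater("0.26.0")
        def test_balanced_device_map(self):
            ...  # skipped unless accelerate > 0.26.0 and more than one GPU is visible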
diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py
index ebf46e396284..fc28d94c240b 100644
--- a/tests/lora/test_lora_layers_sd.py
+++ b/tests/lora/test_lora_layers_sd.py
@@ -150,6 +150,54 @@ def test_integration_move_lora_cpu(self):
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
+ @require_torch_gpu
+ def test_integration_move_lora_dora_cpu(self):
+ from peft import LoraConfig
+
+ path = "runwayml/stable-diffusion-v1-5"
+ unet_lora_config = LoraConfig(
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ use_dora=True,
+ )
+ text_lora_config = LoraConfig(
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ use_dora=True,
+ )
+
+ pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
+ pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder),
+ "Lora not correctly set in text encoder",
+ )
+
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.unet),
+ "Lora not correctly set in text encoder",
+ )
+
+ for name, param in pipe.unet.named_parameters():
+ if "lora_" in name:
+ self.assertEqual(param.device, torch.device("cpu"))
+
+ for name, param in pipe.text_encoder.named_parameters():
+ if "lora_" in name:
+ self.assertEqual(param.device, torch.device("cpu"))
+
+ pipe.set_lora_device(["adapter-1"], torch_device)
+
+ for name, param in pipe.unet.named_parameters():
+ if "lora_" in name:
+ self.assertNotEqual(param.device, torch.device("cpu"))
+
+ for name, param in pipe.text_encoder.named_parameters():
+ if "lora_" in name:
+ self.assertNotEqual(param.device, torch.device("cpu"))
+
@slow
@require_torch_gpu
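Outside the test harness, the same DoRA-on-CPU-then-move flow looks roughly like this (a sketch; needs `peft`, a CUDA device, and the SD 1.5 checkpoint):

    import torch
    from peft import LoraConfig
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
    pipe.unet.add_adapter(
        LoraConfig(init_lora_weights="gaussian", target_modules=["to_k", "to_q", "to_v", "to_out.0"], use_dora=True),
        "adapter-1",
    )

    # freshly added LoRA/DoRA weights live on the CPU until the adapter is moved explicitly
    pipe.set_lora_device(["adapter-1"], "cuda")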
diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py
index b0c24b8d4315..026e01f0ed6a 100644
--- a/tests/models/autoencoders/test_models_vae.py
+++ b/tests/models/autoencoders/test_models_vae.py
@@ -53,8 +53,8 @@
def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
- block_out_channels = block_out_channels or [32, 64]
- norm_num_groups = norm_num_groups or 32
+ block_out_channels = block_out_channels or [2, 4]
+ norm_num_groups = norm_num_groups or 2
init_dict = {
"block_out_channels": block_out_channels,
"in_channels": 3,
@@ -68,8 +68,8 @@ def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
def get_asym_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
- block_out_channels = block_out_channels or [32, 64]
- norm_num_groups = norm_num_groups or 32
+ block_out_channels = block_out_channels or [2, 4]
+ norm_num_groups = norm_num_groups or 2
init_dict = {
"in_channels": 3,
"out_channels": 3,
@@ -102,8 +102,8 @@ def get_autoencoder_tiny_config(block_out_channels=None):
def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None):
- block_out_channels = block_out_channels or [32, 64]
- norm_num_groups = norm_num_groups or 32
+ block_out_channels = block_out_channels or [2, 4]
+ norm_num_groups = norm_num_groups or 2
return {
"encoder_block_out_channels": block_out_channels,
"encoder_in_channels": 3,
diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py
index 8b138bf67f41..d4262e2709dc 100644
--- a/tests/models/autoencoders/test_models_vq.py
+++ b/tests/models/autoencoders/test_models_vq.py
@@ -54,7 +54,8 @@ def output_shape(self):
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
- "block_out_channels": [32, 64],
+ "block_out_channels": [8, 16],
+ "norm_num_groups": 8,
"in_channels": 3,
"out_channels": 3,
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
diff --git a/tests/models/test_attention_processor.py b/tests/models/test_attention_processor.py
index fadee4a9e337..2489604274b4 100644
--- a/tests/models/test_attention_processor.py
+++ b/tests/models/test_attention_processor.py
@@ -80,7 +80,9 @@ def test_only_cross_attention(self):
class DeprecatedAttentionBlockTests(unittest.TestCase):
def test_conversion_when_using_device_map(self):
- pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None)
+ pipe = DiffusionPipeline.from_pretrained(
+ "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
+ )
pre_conversion = pipe(
"foo",
@@ -91,7 +93,7 @@ def test_conversion_when_using_device_map(self):
# the initial conversion succeeds
pipe = DiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-pipe", device_map="sequential", safety_checker=None
+ "hf-internal-testing/tiny-stable-diffusion-torch", device_map="balanced", safety_checker=None
)
conversion = pipe(
@@ -106,8 +108,7 @@ def test_conversion_when_using_device_map(self):
pipe.save_pretrained(tmpdir)
# can also load the converted weights
- pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="sequential", safety_checker=None)
-
+ pipe = DiffusionPipeline.from_pretrained(tmpdir, device_map="balanced", safety_checker=None)
after_conversion = pipe(
"foo",
num_inference_steps=2,
@@ -115,5 +116,5 @@ def test_conversion_when_using_device_map(self):
output_type="np",
).images
- self.assertTrue(np.allclose(pre_conversion, conversion, atol=1e-5))
- self.assertTrue(np.allclose(conversion, after_conversion, atol=1e-5))
+ self.assertTrue(np.allclose(pre_conversion, conversion, atol=1e-3))
+ self.assertTrue(np.allclose(conversion, after_conversion, atol=1e-3))
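For context, the updated checkpoint/device-map pairing used in this test boils down to the following sketch (needs `accelerate` and at least one visible GPU for the "balanced" placement):

    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "hf-internal-testing/tiny-stable-diffusion-torch", device_map="balanced", safety_checker=None
    )
    images = pipe("foo", num_inference_steps=2, output_type="np").images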
diff --git a/tests/models/unets/test_models_unet_1d.py b/tests/models/unets/test_models_unet_1d.py
index bd0506c8bcb9..9f7ef3bca085 100644
--- a/tests/models/unets/test_models_unet_1d.py
+++ b/tests/models/unets/test_models_unet_1d.py
@@ -77,7 +77,7 @@ def test_output(self):
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
- "block_out_channels": (32, 64, 128, 256),
+ "block_out_channels": (8, 8, 16, 16),
"in_channels": 14,
"out_channels": 14,
"time_embedding_type": "positional",
diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py
index 3397e2188bb9..97329898991f 100644
--- a/tests/models/unets/test_models_unet_2d.py
+++ b/tests/models/unets/test_models_unet_2d.py
@@ -63,7 +63,8 @@ def output_shape(self):
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
- "block_out_channels": (32, 64),
+ "block_out_channels": (4, 8),
+ "norm_num_groups": 2,
"down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
"up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
"attention_head_dim": 3,
@@ -78,9 +79,8 @@ def prepare_init_args_and_inputs_for_common(self):
def test_mid_block_attn_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- init_dict["norm_num_groups"] = 16
init_dict["add_attention"] = True
- init_dict["attn_norm_num_groups"] = 8
+ init_dict["attn_norm_num_groups"] = 4
model = self.model_class(**init_dict)
model.to(torch_device)
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index a19e8f8c65c2..1b8a998cfd66 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -30,7 +30,7 @@
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
-from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
+from diffusers.models.embeddings import ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterPlusImageProjection
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
@@ -190,6 +190,64 @@ def create_ip_adapter_plus_state_dict(model):
return ip_state_dict
+def create_ip_adapter_faceid_state_dict(model):
+ # "ip_adapter" (cross-attention weights)
+ # no LoRA weights
+ ip_cross_attn_state_dict = {}
+ key_id = 1
+
+ for name in model.attn_processors.keys():
+ cross_attention_dim = (
+ None if name.endswith("attn1.processor") or "motion_module" in name else model.config.cross_attention_dim
+ )
+
+ if name.startswith("mid_block"):
+ hidden_size = model.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = model.config.block_out_channels[block_id]
+
+ if cross_attention_dim is not None:
+ sd = IPAdapterAttnProcessor(
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+ ).state_dict()
+ ip_cross_attn_state_dict.update(
+ {
+ f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
+ f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
+ }
+ )
+
+ key_id += 2
+
+ # "image_proj" (ImageProjection layer weights)
+ cross_attention_dim = model.config["cross_attention_dim"]
+ image_projection = IPAdapterFaceIDImageProjection(
+ cross_attention_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, mult=2, num_tokens=4
+ )
+
+ ip_image_projection_state_dict = {}
+ sd = image_projection.state_dict()
+ ip_image_projection_state_dict.update(
+ {
+ "proj.0.weight": sd["ff.net.0.proj.weight"],
+ "proj.0.bias": sd["ff.net.0.proj.bias"],
+ "proj.2.weight": sd["ff.net.2.weight"],
+ "proj.2.bias": sd["ff.net.2.bias"],
+ "norm.weight": sd["norm.weight"],
+ "norm.bias": sd["norm.bias"],
+ }
+ )
+
+ del sd
+ ip_state_dict = {}
+ ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
+ return ip_state_dict
+
+
def create_custom_diffusion_layers(model, mock_weights: bool = True):
train_kv = True
train_q_out = True
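For orientation, the dict returned by the new FaceID helper has the same two top-level groups as a real IP-Adapter checkpoint. A rough sketch, assuming `model` is the tiny UNet2DConditionModel these tests construct:

    ip_state_dict = create_ip_adapter_faceid_state_dict(model)

    assert set(ip_state_dict) == {"image_proj", "ip_adapter"}
    # "image_proj" carries the IPAdapterFaceIDImageProjection weights remapped to proj.0 / proj.2 / norm keys;
    # "ip_adapter" carries one to_k_ip / to_v_ip weight pair per cross-attention processor.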
@@ -247,33 +305,34 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
def dummy_input(self):
batch_size = 4
num_channels = 4
- sizes = (32, 32)
+ sizes = (16, 16)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
- encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+ encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@property
def input_shape(self):
- return (4, 32, 32)
+ return (4, 16, 16)
@property
def output_shape(self):
- return (4, 32, 32)
+ return (4, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
- "block_out_channels": (32, 64),
+ "block_out_channels": (4, 8),
+ "norm_num_groups": 4,
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
- "cross_attention_dim": 32,
- "attention_head_dim": 8,
+ "cross_attention_dim": 8,
+ "attention_head_dim": 2,
"out_channels": 4,
"in_channels": 4,
- "layers_per_block": 2,
- "sample_size": 32,
+ "layers_per_block": 1,
+ "sample_size": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@@ -337,6 +396,7 @@ def test_gradient_checkpointing(self):
def test_model_with_attention_head_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
@@ -375,7 +435,7 @@ def test_model_with_use_linear_projection(self):
def test_model_with_cross_attention_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- init_dict["cross_attention_dim"] = (32, 32)
+ init_dict["cross_attention_dim"] = (8, 8)
model = self.model_class(**init_dict)
model.to(torch_device)
@@ -443,6 +503,7 @@ def test_model_with_class_embeddings_concat(self):
def test_model_attention_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
@@ -467,6 +528,7 @@ def test_model_attention_slicing(self):
def test_model_sliceable_head_dim(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
@@ -485,6 +547,7 @@ def check_sliceable_dim_attr(module: torch.nn.Module):
def test_gradient_checkpointing_is_applied(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
model_class_copy = copy.copy(self.model_class)
@@ -561,6 +624,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
@@ -571,7 +635,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
model.set_attn_processor(processor)
model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample
- assert processor.counter == 12
+ assert processor.counter == 8
assert processor.is_run
assert processor.number == 123
@@ -587,7 +651,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
def test_model_xattn_mask(self, mask_dtype):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
+ model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)})
model.to(torch_device)
model.eval()
@@ -649,6 +713,7 @@ def test_custom_diffusion_processors(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
@@ -675,6 +740,7 @@ def test_custom_diffusion_save_load(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
torch.manual_seed(0)
@@ -714,6 +780,7 @@ def test_custom_diffusion_xformers_on_off(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
torch.manual_seed(0)
@@ -739,6 +806,7 @@ def test_pickle(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
@@ -770,6 +838,7 @@ def test_asymmetrical_unet(self):
def test_ip_adapter(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
@@ -842,6 +911,7 @@ def test_ip_adapter(self):
def test_ip_adapter_plus(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
diff --git a/tests/models/unets/test_models_unet_3d_condition.py b/tests/models/unets/test_models_unet_3d_condition.py
index adbf89597435..e798586b6965 100644
--- a/tests/models/unets/test_models_unet_3d_condition.py
+++ b/tests/models/unets/test_models_unet_3d_condition.py
@@ -41,36 +41,37 @@ def dummy_input(self):
batch_size = 4
num_channels = 4
num_frames = 4
- sizes = (32, 32)
+ sizes = (16, 16)
noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
- encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+ encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@property
def input_shape(self):
- return (4, 4, 32, 32)
+ return (4, 4, 16, 16)
@property
def output_shape(self):
- return (4, 4, 32, 32)
+ return (4, 4, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
- "block_out_channels": (32, 64),
+ "block_out_channels": (4, 8),
+ "norm_num_groups": 4,
"down_block_types": (
"CrossAttnDownBlock3D",
"DownBlock3D",
),
"up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D"),
- "cross_attention_dim": 32,
- "attention_head_dim": 8,
+ "cross_attention_dim": 8,
+ "attention_head_dim": 2,
"out_channels": 4,
"in_channels": 4,
"layers_per_block": 1,
- "sample_size": 32,
+ "sample_size": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@@ -93,7 +94,7 @@ def test_xformers_enable_works(self):
# Overriding to set `norm_num_groups` needs to be different for this model.
def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
+ init_dict["block_out_channels"] = (32, 64)
init_dict["norm_num_groups"] = 32
model = self.model_class(**init_dict)
@@ -140,6 +141,7 @@ def test_determinism(self):
def test_model_attention_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = 8
model = self.model_class(**init_dict)
@@ -163,6 +165,7 @@ def test_model_attention_slicing(self):
def test_feed_forward_chunking(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ init_dict["block_out_channels"] = (32, 64)
init_dict["norm_num_groups"] = 32
model = self.model_class(**init_dict)
diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py
new file mode 100644
index 000000000000..8c9b43a20ad6
--- /dev/null
+++ b/tests/models/unets/test_models_unet_controlnetxs.py
@@ -0,0 +1,352 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import unittest
+
+import numpy as np
+import torch
+from torch import nn
+
+from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
+from diffusers.utils import logging
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ floats_tensor,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+
+
+logger = logging.get_logger(__name__)
+
+enable_full_determinism()
+
+
+class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+ model_class = UNetControlNetXSModel
+ main_input_name = "sample"
+
+ @property
+ def dummy_input(self):
+ batch_size = 4
+ num_channels = 4
+ sizes = (16, 16)
+ conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning
+
+ noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+ time_step = torch.tensor([10]).to(torch_device)
+ encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
+ controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device)
+ conditioning_scale = 1
+
+ return {
+ "sample": noise,
+ "timestep": time_step,
+ "encoder_hidden_states": encoder_hidden_states,
+ "controlnet_cond": controlnet_cond,
+ "conditioning_scale": conditioning_scale,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "sample_size": 16,
+ "down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
+ "up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
+ "block_out_channels": (4, 8),
+ "cross_attention_dim": 8,
+ "transformer_layers_per_block": 1,
+ "num_attention_heads": 2,
+ "norm_num_groups": 4,
+ "upcast_attention": False,
+ "ctrl_block_out_channels": [2, 4],
+ "ctrl_num_attention_heads": 4,
+ "ctrl_max_norm_num_groups": 2,
+ "ctrl_conditioning_embedding_out_channels": (2, 2),
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def get_dummy_unet(self):
+ """For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
+ return UNet2DConditionModel(
+ block_out_channels=(4, 8),
+ layers_per_block=2,
+ sample_size=16,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=8,
+ norm_num_groups=4,
+ use_linear_projection=True,
+ )
+
+ def get_dummy_controlnet_from_unet(self, unet, **kwargs):
+ """For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
+ # size_ratio and conditioning_embedding_out_channels chosen to keep model small
+ return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs)
+
+ def test_from_unet(self):
+ unet = self.get_dummy_unet()
+ controlnet = self.get_dummy_controlnet_from_unet(unet)
+
+ model = UNetControlNetXSModel.from_unet(unet, controlnet)
+ model_state_dict = model.state_dict()
+
+ def assert_equal_weights(module, weight_dict_prefix):
+ for param_name, param_value in module.named_parameters():
+ assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value)
+
+ # # check unet
+ # everything except down, mid, up blocks
+ modules_from_unet = [
+ "time_embedding",
+ "conv_in",
+ "conv_norm_out",
+ "conv_out",
+ ]
+ for p in modules_from_unet:
+ assert_equal_weights(getattr(unet, p), "base_" + p)
+ optional_modules_from_unet = [
+ "class_embedding",
+ "add_time_proj",
+ "add_embedding",
+ ]
+ for p in optional_modules_from_unet:
+ if hasattr(unet, p) and getattr(unet, p) is not None:
+ assert_equal_weights(getattr(unet, p), "base_" + p)
+ # down blocks
+ assert len(unet.down_blocks) == len(model.down_blocks)
+ for i, d in enumerate(unet.down_blocks):
+ assert_equal_weights(d.resnets, f"down_blocks.{i}.base_resnets")
+ if hasattr(d, "attentions"):
+ assert_equal_weights(d.attentions, f"down_blocks.{i}.base_attentions")
+ if hasattr(d, "downsamplers") and getattr(d, "downsamplers") is not None:
+ assert_equal_weights(d.downsamplers[0], f"down_blocks.{i}.base_downsamplers")
+ # mid block
+ assert_equal_weights(unet.mid_block, "mid_block.base_midblock")
+ # up blocks
+ assert len(unet.up_blocks) == len(model.up_blocks)
+ for i, u in enumerate(unet.up_blocks):
+ assert_equal_weights(u.resnets, f"up_blocks.{i}.resnets")
+ if hasattr(u, "attentions"):
+ assert_equal_weights(u.attentions, f"up_blocks.{i}.attentions")
+ if hasattr(u, "upsamplers") and getattr(u, "upsamplers") is not None:
+ assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers")
+
+ # # check controlnet
+ # everything except down, mid, up blocks
+ modules_from_controlnet = {
+ "controlnet_cond_embedding": "controlnet_cond_embedding",
+ "conv_in": "ctrl_conv_in",
+ "control_to_base_for_conv_in": "control_to_base_for_conv_in",
+ }
+ optional_modules_from_controlnet = {"time_embedding": "ctrl_time_embedding"}
+ for name_in_controlnet, name_in_unetcnxs in modules_from_controlnet.items():
+ assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs)
+
+ for name_in_controlnet, name_in_unetcnxs in optional_modules_from_controlnet.items():
+ if hasattr(controlnet, name_in_controlnet) and getattr(controlnet, name_in_controlnet) is not None:
+ assert_equal_weights(getattr(controlnet, name_in_controlnet), name_in_unetcnxs)
+ # down blocks
+ assert len(controlnet.down_blocks) == len(model.down_blocks)
+ for i, d in enumerate(controlnet.down_blocks):
+ assert_equal_weights(d.resnets, f"down_blocks.{i}.ctrl_resnets")
+ assert_equal_weights(d.base_to_ctrl, f"down_blocks.{i}.base_to_ctrl")
+ assert_equal_weights(d.ctrl_to_base, f"down_blocks.{i}.ctrl_to_base")
+ if d.attentions is not None:
+ assert_equal_weights(d.attentions, f"down_blocks.{i}.ctrl_attentions")
+ if d.downsamplers is not None:
+ assert_equal_weights(d.downsamplers, f"down_blocks.{i}.ctrl_downsamplers")
+ # mid block
+ assert_equal_weights(controlnet.mid_block.base_to_ctrl, "mid_block.base_to_ctrl")
+ assert_equal_weights(controlnet.mid_block.midblock, "mid_block.ctrl_midblock")
+ assert_equal_weights(controlnet.mid_block.ctrl_to_base, "mid_block.ctrl_to_base")
+ # up blocks
+ assert len(controlnet.up_connections) == len(model.up_blocks)
+ for i, u in enumerate(controlnet.up_connections):
+ assert_equal_weights(u.ctrl_to_base, f"up_blocks.{i}.ctrl_to_base")
+
+ def test_freeze_unet(self):
+ def assert_frozen(module):
+ for p in module.parameters():
+ assert not p.requires_grad
+
+ def assert_unfrozen(module):
+ for p in module.parameters():
+ assert p.requires_grad
+
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = UNetControlNetXSModel(**init_dict)
+ model.freeze_unet_params()
+
+ # # check unet
+ # everything except down, mid, up blocks
+ modules_from_unet = [
+ model.base_time_embedding,
+ model.base_conv_in,
+ model.base_conv_norm_out,
+ model.base_conv_out,
+ ]
+ for m in modules_from_unet:
+ assert_frozen(m)
+
+ optional_modules_from_unet = [
+ model.base_add_time_proj,
+ model.base_add_embedding,
+ ]
+ for m in optional_modules_from_unet:
+ if m is not None:
+ assert_frozen(m)
+
+ # down blocks
+ for i, d in enumerate(model.down_blocks):
+ assert_frozen(d.base_resnets)
+ if isinstance(d.base_attentions, nn.ModuleList):  # attentions can be a list of Nones
+ assert_frozen(d.base_attentions)
+ if d.base_downsamplers is not None:
+ assert_frozen(d.base_downsamplers)
+
+ # mid block
+ assert_frozen(model.mid_block.base_midblock)
+
+ # up blocks
+ for i, u in enumerate(model.up_blocks):
+ assert_frozen(u.resnets)
+ if isinstance(u.attentions, nn.ModuleList):  # attentions can be a list of Nones
+ assert_frozen(u.attentions)
+ if u.upsamplers is not None:
+ assert_frozen(u.upsamplers)
+
+ # # check controlnet
+ # everything except down, mid, up blocks
+ modules_from_controlnet = [
+ model.controlnet_cond_embedding,
+ model.ctrl_conv_in,
+ model.control_to_base_for_conv_in,
+ ]
+ optional_modules_from_controlnet = [model.ctrl_time_embedding]
+
+ for m in modules_from_controlnet:
+ assert_unfrozen(m)
+ for m in optional_modules_from_controlnet:
+ if m is not None:
+ assert_unfrozen(m)
+
+ # down blocks
+ for d in model.down_blocks:
+ assert_unfrozen(d.ctrl_resnets)
+ assert_unfrozen(d.base_to_ctrl)
+ assert_unfrozen(d.ctrl_to_base)
+ if isinstance(d.ctrl_attentions, nn.ModuleList):  # attentions can be a list of Nones
+ assert_unfrozen(d.ctrl_attentions)
+ if d.ctrl_downsamplers is not None:
+ assert_unfrozen(d.ctrl_downsamplers)
+ # mid block
+ assert_unfrozen(model.mid_block.base_to_ctrl)
+ assert_unfrozen(model.mid_block.ctrl_midblock)
+ assert_unfrozen(model.mid_block.ctrl_to_base)
+ # up blocks
+ for u in model.up_blocks:
+ assert_unfrozen(u.ctrl_to_base)
+
+ def test_gradient_checkpointing_is_applied(self):
+ model_class_copy = copy.copy(UNetControlNetXSModel)
+
+ modules_with_gc_enabled = {}
+
+ # now monkey patch the following function:
+ # def _set_gradient_checkpointing(self, module, value=False):
+ # if hasattr(module, "gradient_checkpointing"):
+ # module.gradient_checkpointing = value
+
+ def _set_gradient_checkpointing_new(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+ modules_with_gc_enabled[module.__class__.__name__] = True
+
+ model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
+
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = model_class_copy(**init_dict)
+
+ model.enable_gradient_checkpointing()
+
+ EXPECTED_SET = {
+ "Transformer2DModel",
+ "UNetMidBlock2DCrossAttn",
+ "ControlNetXSCrossAttnDownBlock2D",
+ "ControlNetXSCrossAttnMidBlock2D",
+ "ControlNetXSCrossAttnUpBlock2D",
+ }
+
+ assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
+ assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+
+ def test_forward_no_control(self):
+ unet = self.get_dummy_unet()
+ controlnet = self.get_dummy_controlnet_from_unet(unet)
+
+ model = UNetControlNetXSModel.from_unet(unet, controlnet)
+
+ unet = unet.to(torch_device)
+ model = model.to(torch_device)
+
+ input_ = self.dummy_input
+
+ control_specific_input = ["controlnet_cond", "conditioning_scale"]
+ input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input}
+
+ with torch.no_grad():
+ unet_output = unet(**input_for_unet).sample.cpu()
+ unet_controlnet_output = model(**input_, apply_control=False).sample.cpu()
+
+ assert np.abs(unet_output.flatten() - unet_controlnet_output.flatten()).max() < 3e-4
+
+ def test_time_embedding_mixing(self):
+ unet = self.get_dummy_unet()
+ controlnet = self.get_dummy_controlnet_from_unet(unet)
+ controlnet_mix_time = self.get_dummy_controlnet_from_unet(
+ unet, time_embedding_mix=0.5, learn_time_embedding=True
+ )
+
+ model = UNetControlNetXSModel.from_unet(unet, controlnet)
+ model_mix_time = UNetControlNetXSModel.from_unet(unet, controlnet_mix_time)
+
+ unet = unet.to(torch_device)
+ model = model.to(torch_device)
+ model_mix_time = model_mix_time.to(torch_device)
+
+ input_ = self.dummy_input
+
+ with torch.no_grad():
+ output = model(**input_).sample
+ output_mix_time = model_mix_time(**input_).sample
+
+ assert output.shape == output_mix_time.shape
+
+ def test_forward_with_norm_groups(self):
+ # UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
+ pass
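Putting the fixtures above together, composing and freezing a ControlNet-XS model outside the test class looks roughly like this (a sketch built from the same tiny configs):

    from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel

    unet = UNet2DConditionModel(
        block_out_channels=(4, 8),
        layers_per_block=2,
        sample_size=16,
        in_channels=4,
        out_channels=4,
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=8,
        norm_num_groups=4,
        use_linear_projection=True,
    )
    adapter = ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2))

    model = UNetControlNetXSModel.from_unet(unet, adapter)
    model.freeze_unet_params()  # leaves only the control branch trainable, as test_freeze_unet checks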
diff --git a/tests/models/unets/test_models_unet_motion.py b/tests/models/unets/test_models_unet_motion.py
index 4f7aba015240..7d83b07c49fe 100644
--- a/tests/models/unets/test_models_unet_motion.py
+++ b/tests/models/unets/test_models_unet_motion.py
@@ -46,34 +46,35 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
def dummy_input(self):
batch_size = 4
num_channels = 4
- num_frames = 8
- sizes = (32, 32)
+ num_frames = 4
+ sizes = (16, 16)
noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
- encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+ encoder_hidden_states = floats_tensor((batch_size, 4, 16)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@property
def input_shape(self):
- return (4, 8, 32, 32)
+ return (4, 4, 16, 16)
@property
def output_shape(self):
- return (4, 8, 32, 32)
+ return (4, 4, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
- "block_out_channels": (32, 64),
+ "block_out_channels": (16, 32),
+ "norm_num_groups": 16,
"down_block_types": ("CrossAttnDownBlockMotion", "DownBlockMotion"),
"up_block_types": ("UpBlockMotion", "CrossAttnUpBlockMotion"),
- "cross_attention_dim": 32,
- "num_attention_heads": 4,
+ "cross_attention_dim": 16,
+ "num_attention_heads": 2,
"out_channels": 4,
"in_channels": 4,
"layers_per_block": 1,
- "sample_size": 32,
+ "sample_size": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@@ -194,6 +195,7 @@ def _set_gradient_checkpointing_new(self, module, value=False):
def test_feed_forward_chunking(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ init_dict["block_out_channels"] = (32, 64)
init_dict["norm_num_groups"] = 32
model = self.model_class(**init_dict)
diff --git a/tests/models/unets/test_models_unet_spatiotemporal.py b/tests/models/unets/test_models_unet_spatiotemporal.py
index 935aa7f6fe4c..afdd3d127702 100644
--- a/tests/models/unets/test_models_unet_spatiotemporal.py
+++ b/tests/models/unets/test_models_unet_spatiotemporal.py
@@ -24,6 +24,7 @@
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
+ skip_mps,
torch_all_close,
torch_device,
)
@@ -36,6 +37,7 @@
enable_full_determinism()
+@skip_mps
class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNetSpatioTemporalConditionModel
main_input_name = "sample"
diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py
index cabfd29e0d32..bc847400adc3 100644
--- a/tests/pipelines/animatediff/test_animatediff_video2video.py
+++ b/tests/pipelines/animatediff/test_animatediff_video2video.py
@@ -269,6 +269,17 @@ def test_prompt_embeds(self):
inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device)
pipe(**inputs)
+ def test_latent_inputs(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+ pipe.to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["latents"] = torch.randn((1, 4, 1, 32, 32), device=torch_device)
+ inputs.pop("video")
+ pipe(**inputs)
+
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py
index ca8652a0b555..08fd361940a9 100644
--- a/tests/pipelines/audioldm2/test_audioldm2.py
+++ b/tests/pipelines/audioldm2/test_audioldm2.py
@@ -516,6 +516,20 @@ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0
}
return inputs
+ def get_inputs_tts(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
+ generator = torch.Generator(device=generator_device).manual_seed(seed)
+ latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
+ latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+ inputs = {
+ "prompt": "A men saying",
+ "transcription": "hello my name is John",
+ "latents": latents,
+ "generator": generator,
+ "num_inference_steps": 3,
+ "guidance_scale": 2.5,
+ }
+ return inputs
+
def test_audioldm2(self):
audioldm_pipe = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2")
audioldm_pipe = audioldm_pipe.to(torch_device)
@@ -572,3 +586,22 @@ def test_audioldm2_large(self):
)
max_diff = np.abs(expected_slice - audio_slice).max()
assert max_diff < 1e-3
+
+ def test_audioldm2_tts(self):
+ audioldm_tts_pipe = AudioLDM2Pipeline.from_pretrained("anhnct/audioldm2_gigaspeech")
+ audioldm_tts_pipe = audioldm_tts_pipe.to(torch_device)
+ audioldm_tts_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_inputs_tts(torch_device)
+ audio = audioldm_tts_pipe(**inputs).audios[0]
+
+ assert audio.ndim == 1
+ assert len(audio) == 81952
+
+ # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
+ audio_slice = audio[8825:8835]
+ expected_slice = np.array(
+ [-0.1829, -0.1461, 0.0759, -0.1493, -0.1396, 0.5783, 0.3001, -0.3038, -0.0639, -0.2244]
+ )
+ max_diff = np.abs(expected_slice - audio_slice).max()
+ assert max_diff < 1e-3
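The new TTS path boils down to passing a `transcription` alongside the prompt; a minimal sketch of the same call (assumes a CUDA device and the Hub checkpoint used above):

    import torch
    from diffusers import AudioLDM2Pipeline

    pipe = AudioLDM2Pipeline.from_pretrained("anhnct/audioldm2_gigaspeech").to("cuda")
    audio = pipe(
        prompt="A men saying",
        transcription="hello my name is John",
        num_inference_steps=3,
        guidance_scale=2.5,
        generator=torch.Generator("cpu").manual_seed(0),
    ).audios[0]  # 1-D waveform array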
diff --git a/tests/pipelines/controlnet_xs/__init__.py b/tests/pipelines/controlnet_xs/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py
new file mode 100644
index 000000000000..42795807792b
--- /dev/null
+++ b/tests/pipelines/controlnet_xs/test_controlnetxs.py
@@ -0,0 +1,398 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import traceback
+import unittest
+
+import numpy as np
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+from diffusers import (
+ AsymmetricAutoencoderKL,
+ AutoencoderKL,
+ AutoencoderTiny,
+ ConsistencyDecoderVAE,
+ ControlNetXSAdapter,
+ DDIMScheduler,
+ LCMScheduler,
+ StableDiffusionControlNetXSPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ load_image,
+ load_numpy,
+ require_python39_or_higher,
+ require_torch_2,
+ require_torch_gpu,
+ run_test_in_subprocess,
+ slow,
+ torch_device,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...models.autoencoders.test_models_vae import (
+ get_asym_autoencoder_kl_config,
+ get_autoencoder_kl_config,
+ get_autoencoder_tiny_config,
+ get_consistency_vae_config,
+)
+from ..pipeline_params import (
+ IMAGE_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_TO_IMAGE_BATCH_PARAMS,
+ TEXT_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_TO_IMAGE_PARAMS,
+)
+from ..test_pipelines_common import (
+ PipelineKarrasSchedulerTesterMixin,
+ PipelineLatentTesterMixin,
+ PipelineTesterMixin,
+ SDFunctionTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+def to_np(tensor):
+ if isinstance(tensor, torch.Tensor):
+ tensor = tensor.detach().cpu().numpy()
+
+ return tensor
+
+
+# Will be run via run_test_in_subprocess
+def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
+ error = None
+ try:
+ _ = in_queue.get(timeout=timeout)
+
+ controlnet = ControlNetXSAdapter.from_pretrained(
+ "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
+ )
+ pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-1-base",
+ controlnet=controlnet,
+ safety_checker=None,
+ torch_dtype=torch.float16,
+ )
+ pipe.to("cuda")
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe.unet.to(memory_format=torch.channels_last)
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+ prompt = "bird"
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+ ).resize((512, 512))
+
+ output = pipe(prompt, image, num_inference_steps=10, generator=generator, output_type="np")
+ image = output.images[0]
+
+ assert image.shape == (512, 512, 3)
+
+ expected_image = load_numpy(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy"
+ )
+ expected_image = np.resize(expected_image, (512, 512, 3))
+
+ assert np.abs(expected_image - image).max() < 1.0
+
+ except Exception:
+ error = f"{traceback.format_exc()}"
+
+ results = {"error": error}
+ out_queue.put(results, timeout=timeout)
+ out_queue.join()
+
+
+class ControlNetXSPipelineFastTests(
+ PipelineLatentTesterMixin,
+ PipelineKarrasSchedulerTesterMixin,
+ PipelineTesterMixin,
+ SDFunctionTesterMixin,
+ unittest.TestCase,
+):
+ pipeline_class = StableDiffusionControlNetXSPipeline
+ params = TEXT_TO_IMAGE_PARAMS
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+
+ test_attention_slicing = False
+
+ def get_dummy_components(self, time_cond_proj_dim=None):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(4, 8),
+ layers_per_block=2,
+ sample_size=16,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ cross_attention_dim=8,
+ norm_num_groups=4,
+ time_cond_proj_dim=time_cond_proj_dim,
+ use_linear_projection=True,
+ )
+ torch.manual_seed(0)
+ controlnet = ControlNetXSAdapter.from_unet(
+ unet=unet,
+ size_ratio=1,
+ learn_time_embedding=True,
+ conditioning_embedding_out_channels=(2, 2),
+ )
+ torch.manual_seed(0)
+ scheduler = DDIMScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ clip_sample=False,
+ set_alpha_to_one=False,
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[4, 8],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ norm_num_groups=2,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=8,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "unet": unet,
+ "controlnet": controlnet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "safety_checker": None,
+ "feature_extractor": None,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ controlnet_embedder_scale_factor = 2
+ image = randn_tensor(
+ (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor),
+ generator=generator,
+ device=torch.device(device),
+ )
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "output_type": "numpy",
+ "image": image,
+ }
+
+ return inputs
+
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_attention_forwardGenerator_pass(self):
+ self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
+
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+
+ def test_controlnet_lcm(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+
+ components = self.get_dummy_components(time_cond_proj_dim=8)
+ sd_pipe = StableDiffusionControlNetXSPipeline(**components)
+ sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+ sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ output = sd_pipe(**inputs)
+ image = output.images
+
+ image_slice = image[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 16, 16, 3)
+ expected_slice = np.array([0.745, 0.753, 0.767, 0.543, 0.523, 0.502, 0.314, 0.521, 0.478])
+
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
+ def test_to_dtype(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+ pipe.to(dtype=torch.float16)
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
+ def test_multi_vae(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ block_out_channels = pipe.vae.config.block_out_channels
+ norm_num_groups = pipe.vae.config.norm_num_groups
+
+ vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
+ configs = [
+ get_autoencoder_kl_config(block_out_channels, norm_num_groups),
+ get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
+ get_consistency_vae_config(block_out_channels, norm_num_groups),
+ get_autoencoder_tiny_config(block_out_channels),
+ ]
+
+ out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+
+ for vae_cls, config in zip(vae_classes, configs):
+ vae = vae_cls(**config)
+ vae = vae.to(torch_device)
+ components["vae"] = vae
+ vae_pipe = self.pipeline_class(**components)
+
+ # pipeline creates a new UNetControlNetXSModel under the hood, whose weights aren't on the target device yet,
+ # so we need to move the new pipe to the device.
+ vae_pipe.to(torch_device)
+ vae_pipe.set_progress_bar_config(disable=None)
+
+ out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+
+ assert out_vae_np.shape == out_np.shape
+
+ @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices")
+ def test_to_device(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe.to("cpu")
+ # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the device from pipe.components
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cpu" for device in model_devices))
+
+ output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
+ self.assertTrue(np.isnan(output_cpu).sum() == 0)
+
+ pipe.to("cuda")
+ model_devices = [
+ component.device.type for component in pipe.components.values() if hasattr(component, "device")
+ ]
+ self.assertTrue(all(device == "cuda" for device in model_devices))
+
+ output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
+ self.assertTrue(np.isnan(to_np(output_cuda)).sum() == 0)
+
+
+@slow
+@require_torch_gpu
+class ControlNetXSPipelineSlowTests(unittest.TestCase):
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_canny(self):
+ controlnet = ControlNetXSAdapter.from_pretrained(
+ "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
+ )
+ pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
+ )
+ pipe.enable_model_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+ prompt = "bird"
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+ )
+
+ output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
+
+ image = output.images[0]
+
+ assert image.shape == (768, 512, 3)
+
+ original_image = image[-3:, -3:, -1].flatten()
+ expected_image = np.array([0.1963, 0.229, 0.2659, 0.2109, 0.2332, 0.2827, 0.2534, 0.2422, 0.2808])
+ assert np.allclose(original_image, expected_image, atol=1e-04)
+
+ def test_depth(self):
+ controlnet = ControlNetXSAdapter.from_pretrained(
+ "UmerHA/Testing-ConrolNetXS-SD2.1-depth", torch_dtype=torch.float16
+ )
+ pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
+ )
+ pipe.enable_model_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+ prompt = "Stormtrooper's lecture"
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
+ )
+
+ output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
+
+ image = output.images[0]
+
+ assert image.shape == (512, 512, 3)
+
+ original_image = image[-3:, -3:, -1].flatten()
+ expected_image = np.array([0.4844, 0.4937, 0.4956, 0.4663, 0.5039, 0.5044, 0.4565, 0.4883, 0.4941])
+ assert np.allclose(original_image, expected_image, atol=1e-04)
+
+ @require_python39_or_higher
+ @require_torch_2
+ def test_stable_diffusion_compile(self):
+ run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None)
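Stripped of the test scaffolding, the canny flow exercised above reduces to the following sketch (needs a CUDA GPU and downloads both checkpoints):

    import torch
    from diffusers import ControlNetXSAdapter, StableDiffusionControlNetXSPipeline
    from diffusers.utils import load_image

    controlnet = ControlNetXSAdapter.from_pretrained(
        "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
    )
    pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
    )
    pipe.enable_model_cpu_offload()

    canny_image = load_image(
        "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
    )
    image = pipe("bird", canny_image, num_inference_steps=3, output_type="np").images[0]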
diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py
new file mode 100644
index 000000000000..ee0d15ec3472
--- /dev/null
+++ b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py
@@ -0,0 +1,425 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import unittest
+
+import numpy as np
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
+
+from diffusers import (
+ AsymmetricAutoencoderKL,
+ AutoencoderKL,
+ AutoencoderTiny,
+ ConsistencyDecoderVAE,
+ ControlNetXSAdapter,
+ EulerDiscreteScheduler,
+ StableDiffusionXLControlNetXSPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device
+from diffusers.utils.torch_utils import randn_tensor
+
+from ...models.autoencoders.test_models_vae import (
+ get_asym_autoencoder_kl_config,
+ get_autoencoder_kl_config,
+ get_autoencoder_tiny_config,
+ get_consistency_vae_config,
+)
+from ..pipeline_params import (
+ IMAGE_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_TO_IMAGE_BATCH_PARAMS,
+ TEXT_TO_IMAGE_IMAGE_PARAMS,
+ TEXT_TO_IMAGE_PARAMS,
+)
+from ..test_pipelines_common import (
+ PipelineKarrasSchedulerTesterMixin,
+ PipelineLatentTesterMixin,
+ PipelineTesterMixin,
+ SDXLOptionalComponentsTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class StableDiffusionXLControlNetXSPipelineFastTests(
+ PipelineLatentTesterMixin,
+ PipelineKarrasSchedulerTesterMixin,
+ PipelineTesterMixin,
+ SDXLOptionalComponentsTesterMixin,
+ unittest.TestCase,
+):
+ pipeline_class = StableDiffusionXLControlNetXSPipeline
+ params = TEXT_TO_IMAGE_PARAMS
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+
+ test_attention_slicing = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ unet = UNet2DConditionModel(
+ block_out_channels=(4, 8),
+ layers_per_block=2,
+ sample_size=16,
+ in_channels=4,
+ out_channels=4,
+ down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+ up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+ use_linear_projection=True,
+ norm_num_groups=4,
+ # SD2-specific config below
+ attention_head_dim=(2, 4),
+ addition_embed_type="text_time",
+ addition_time_embed_dim=8,
+ transformer_layers_per_block=(1, 2),
+ projection_class_embeddings_input_dim=56, # 6 * 8 (addition_time_embed_dim) + 8 (cross_attention_dim)
+ cross_attention_dim=8,
+ )
+ torch.manual_seed(0)
+ controlnet = ControlNetXSAdapter.from_unet(
+ unet=unet,
+ size_ratio=0.5,
+ learn_time_embedding=True,
+ conditioning_embedding_out_channels=(2, 2),
+ )
+ torch.manual_seed(0)
+ scheduler = EulerDiscreteScheduler(
+ beta_start=0.00085,
+ beta_end=0.012,
+ steps_offset=1,
+ beta_schedule="scaled_linear",
+ timestep_spacing="leading",
+ )
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ block_out_channels=[4, 8],
+ in_channels=3,
+ out_channels=3,
+ down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+ up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+ latent_channels=4,
+ norm_num_groups=2,
+ )
+ torch.manual_seed(0)
+ text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=4,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ # SD2-specific config below
+ hidden_act="gelu",
+ projection_dim=8,
+ )
+ text_encoder = CLIPTextModel(text_encoder_config)
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ components = {
+ "unet": unet,
+ "controlnet": controlnet,
+ "scheduler": scheduler,
+ "vae": vae,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer_2": tokenizer_2,
+ "feature_extractor": None,
+ }
+ return components
+
+ # copied from test_controlnet_sdxl.py
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ controlnet_embedder_scale_factor = 2
+ image = randn_tensor(
+ (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor),
+ generator=generator,
+ device=torch.device(device),
+ )
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "output_type": "np",
+ "image": image,
+ }
+
+ return inputs
+
+ # copied from test_controlnet_sdxl.py
+ def test_attention_slicing_forward_pass(self):
+ return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
+
+ # copied from test_controlnet_sdxl.py
+ @unittest.skipIf(
+ torch_device != "cuda" or not is_xformers_available(),
+ reason="XFormers attention is only available with CUDA and `xformers` installed",
+ )
+ def test_xformers_attention_forwardGenerator_pass(self):
+ self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
+
+ # copied from test_controlnet_sdxl.py
+ def test_inference_batch_single_identical(self):
+ self._test_inference_batch_single_identical(expected_max_diff=2e-3)
+
+ # copied from test_controlnet_sdxl.py
+ @require_torch_gpu
+ def test_stable_diffusion_xl_offloads(self):
+ pipes = []
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components).to(torch_device)
+ pipes.append(sd_pipe)
+
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components)
+ sd_pipe.enable_model_cpu_offload()
+ pipes.append(sd_pipe)
+
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components)
+ sd_pipe.enable_sequential_cpu_offload()
+ pipes.append(sd_pipe)
+
+ image_slices = []
+ for pipe in pipes:
+ pipe.unet.set_default_attn_processor()
+
+ inputs = self.get_dummy_inputs(torch_device)
+ image = pipe(**inputs).images
+
+ image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+ assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+ assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
+
+ # copied from test_controlnet_sdxl.py
+ def test_stable_diffusion_xl_multi_prompts(self):
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components).to(torch_device)
+
+ # forward with single prompt
+ inputs = self.get_dummy_inputs(torch_device)
+ output = sd_pipe(**inputs)
+ image_slice_1 = output.images[0, -3:, -3:, -1]
+
+ # forward with same prompt duplicated
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt_2"] = inputs["prompt"]
+ output = sd_pipe(**inputs)
+ image_slice_2 = output.images[0, -3:, -3:, -1]
+
+ # ensure the results are equal
+ assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
+ # forward with different prompt
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt_2"] = "different prompt"
+ output = sd_pipe(**inputs)
+ image_slice_3 = output.images[0, -3:, -3:, -1]
+
+ # ensure the results are not equal
+ assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
+
+ # manually set a negative_prompt
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["negative_prompt"] = "negative prompt"
+ output = sd_pipe(**inputs)
+ image_slice_1 = output.images[0, -3:, -3:, -1]
+
+ # forward with same negative_prompt duplicated
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["negative_prompt"] = "negative prompt"
+ inputs["negative_prompt_2"] = inputs["negative_prompt"]
+ output = sd_pipe(**inputs)
+ image_slice_2 = output.images[0, -3:, -3:, -1]
+
+ # ensure the results are equal
+ assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
+
+ # forward with different negative_prompt
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["negative_prompt"] = "negative prompt"
+ inputs["negative_prompt_2"] = "different negative prompt"
+ output = sd_pipe(**inputs)
+ image_slice_3 = output.images[0, -3:, -3:, -1]
+
+ # ensure the results are not equal
+ assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
+
+ # copied from test_stable_diffusion_xl.py
+ def test_stable_diffusion_xl_prompt_embeds(self):
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=None)
+
+ # forward without prompt embeds
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt"] = 2 * [inputs["prompt"]]
+ inputs["num_images_per_prompt"] = 2
+
+ output = sd_pipe(**inputs)
+ image_slice_1 = output.images[0, -3:, -3:, -1]
+
+ # forward with prompt embeds
+ inputs = self.get_dummy_inputs(torch_device)
+ prompt = 2 * [inputs.pop("prompt")]
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = sd_pipe.encode_prompt(prompt)
+
+ output = sd_pipe(
+ **inputs,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ )
+ image_slice_2 = output.images[0, -3:, -3:, -1]
+
+        # make sure the results are equal
+ assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1.1e-4
+
+ # copied from test_stable_diffusion_xl.py
+ def test_save_load_optional_components(self):
+ self._test_save_load_optional_components()
+
+ # copied from test_controlnetxs.py
+ def test_to_dtype(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.set_progress_bar_config(disable=None)
+
+        # The pipeline creates a new UNetControlNetXSModel under the hood, so we check the dtypes via pipe.components.
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+ pipe.to(dtype=torch.float16)
+ model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
+ self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
+ def test_multi_vae(self):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ block_out_channels = pipe.vae.config.block_out_channels
+ norm_num_groups = pipe.vae.config.norm_num_groups
+
+ vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
+ configs = [
+ get_autoencoder_kl_config(block_out_channels, norm_num_groups),
+ get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
+ get_consistency_vae_config(block_out_channels, norm_num_groups),
+ get_autoencoder_tiny_config(block_out_channels),
+ ]
+
+ out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+
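+        # swap in each alternative VAE class and check that the pipeline still runs and returns images of the same shape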
+ for vae_cls, config in zip(vae_classes, configs):
+ vae = vae_cls(**config)
+ vae = vae.to(torch_device)
+ components["vae"] = vae
+ vae_pipe = self.pipeline_class(**components)
+
+            # The pipeline creates a new UNetControlNetXSModel under the hood, which isn't on the device yet,
+            # so we need to move the new pipeline to the device.
+ vae_pipe.to(torch_device)
+ vae_pipe.set_progress_bar_config(disable=None)
+
+ out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
+
+ assert out_vae_np.shape == out_np.shape
+
+
+@slow
+@require_torch_gpu
+class StableDiffusionXLControlNetXSPipelineSlowTests(unittest.TestCase):
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_canny(self):
+ controlnet = ControlNetXSAdapter.from_pretrained(
+ "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16
+ )
+ pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
+ )
+ pipe.enable_sequential_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+ prompt = "bird"
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+ )
+
+ images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
+
+ assert images[0].shape == (768, 512, 3)
+
+ original_image = images[0, -3:, -3:, -1].flatten()
+ expected_image = np.array([0.3202, 0.3151, 0.3328, 0.3172, 0.337, 0.3381, 0.3378, 0.3389, 0.3224])
+ assert np.allclose(original_image, expected_image, atol=1e-04)
+
+ def test_depth(self):
+ controlnet = ControlNetXSAdapter.from_pretrained(
+ "UmerHA/Testing-ConrolNetXS-SDXL-depth", torch_dtype=torch.float16
+ )
+ pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
+ )
+ pipe.enable_sequential_cpu_offload()
+ pipe.set_progress_bar_config(disable=None)
+
+ generator = torch.Generator(device="cpu").manual_seed(0)
+ prompt = "Stormtrooper's lecture"
+ image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
+ )
+
+ images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
+
+ assert images[0].shape == (512, 512, 3)
+
+ original_image = images[0, -3:, -3:, -1].flatten()
+ expected_image = np.array([0.5448, 0.5437, 0.5426, 0.5543, 0.553, 0.5475, 0.5595, 0.5602, 0.5529])
+ assert np.allclose(original_image, expected_image, atol=1e-04)
diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
index bed3ca82b8d2..ef70baa05f19 100644
--- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
+++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py
@@ -37,6 +37,7 @@
from diffusers.utils.testing_utils import (
enable_full_determinism,
is_flaky,
+ load_pt,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
@@ -306,6 +307,35 @@ def test_multi(self):
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4
+ def test_text_to_image_face_id(self):
+ pipeline = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", safety_checker=None, torch_dtype=self.dtype
+ )
+ pipeline.to(torch_device)
+ pipeline.load_ip_adapter(
+ "h94/IP-Adapter-FaceID",
+ subfolder=None,
+ weight_name="ip-adapter-faceid_sd15.bin",
+ image_encoder_folder=None,
+ )
+ pipeline.set_ip_adapter_scale(0.7)
+
+ inputs = self.get_dummy_inputs()
+ id_embeds = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt")[
+ 0
+ ]
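+        # reshape the loaded ID embeds into the 4D (2, 1, 1, 512) layout used for FaceID image embeds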
+ id_embeds = id_embeds.reshape((2, 1, 1, 512))
+ inputs["ip_adapter_image_embeds"] = [id_embeds]
+ inputs["ip_adapter_image"] = None
+ images = pipeline(**inputs).images
+ image_slice = images[0, :3, :3, -1].flatten()
+
+ expected_slice = np.array(
+ [0.32714844, 0.3239746, 0.3466797, 0.31835938, 0.30004883, 0.3251953, 0.3215332, 0.3552246, 0.3251953]
+ )
+ max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+ assert max_diff < 5e-4
+
@slow
@require_torch_gpu
@@ -544,3 +574,33 @@ def test_ip_adapter_multiple_masks(self):
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4
+
+ def test_ip_adapter_multiple_masks_one_adapter(self):
+ image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
+ pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ image_encoder=image_encoder,
+ torch_dtype=self.dtype,
+ )
+ pipeline.enable_model_cpu_offload()
+ pipeline.load_ip_adapter(
+ "h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"]
+ )
+ pipeline.set_ip_adapter_scale([[0.7, 0.7]])
+
+ inputs = self.get_dummy_inputs(for_masks=True)
+ masks = inputs["cross_attention_kwargs"]["ip_adapter_masks"]
+ processor = IPAdapterMaskProcessor()
+ masks = processor.preprocess(masks)
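+        # fold both masks into a single adapter entry of shape (1, num_masks, height, width)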
+ masks = masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])
+ inputs["cross_attention_kwargs"]["ip_adapter_masks"] = [masks]
+ ip_images = inputs["ip_adapter_image"]
+ inputs["ip_adapter_image"] = [[image[0] for image in ip_images]]
+ images = pipeline(**inputs).images
+ image_slice = images[0, :3, :3, -1].flatten()
+ expected_slice = np.array(
+ [0.79474676, 0.7977683, 0.8013954, 0.7988008, 0.7970615, 0.8029355, 0.80614823, 0.8050743, 0.80627424]
+ )
+
+ max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
+ assert max_diff < 5e-4
diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
index 576fe24bfbfa..9b9a8ef65572 100644
--- a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
+++ b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py
@@ -124,7 +124,7 @@ def test_inference_superresolution(self):
)
init_image = init_image.resize((64, 64), resample=PIL_INTERPOLATION["lanczos"])
- ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution", device_map="auto")
+ ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution")
ldm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index 9a71cc462b10..145e0012f8e9 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -50,10 +50,13 @@
load_numpy,
nightly,
numpy_cosine_similarity_distance,
+ require_accelerate_version_greater,
require_python39_or_higher,
require_torch_2,
require_torch_gpu,
+ require_torch_multi_gpu,
run_test_in_subprocess,
+ skip_mps,
slow,
torch_device,
)
@@ -123,6 +126,8 @@ class StableDiffusionPipelineFastTests(
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
def get_dummy_components(self, time_cond_proj_dim=None):
+ cross_attention_dim = 8
+
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(4, 8),
@@ -133,7 +138,7 @@ def get_dummy_components(self, time_cond_proj_dim=None):
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
+ cross_attention_dim=cross_attention_dim,
norm_num_groups=2,
)
scheduler = DDIMScheduler(
@@ -157,11 +162,11 @@ def get_dummy_components(self, time_cond_proj_dim=None):
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
- hidden_size=32,
- intermediate_size=64,
+ hidden_size=cross_attention_dim,
+ intermediate_size=16,
layer_norm_eps=1e-05,
- num_attention_heads=8,
- num_hidden_layers=3,
+ num_attention_heads=2,
+ num_hidden_layers=2,
pad_token_id=1,
vocab_size=1000,
)
@@ -209,7 +214,7 @@ def test_stable_diffusion_ddim(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3203, 0.4555, 0.4711, 0.3505, 0.3973, 0.4650, 0.5137, 0.3392, 0.4045])
+ expected_slice = np.array([0.1763, 0.4776, 0.4986, 0.2566, 0.3802, 0.4596, 0.5363, 0.3277, 0.3949])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -229,7 +234,7 @@ def test_stable_diffusion_lcm(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3454, 0.5349, 0.5185, 0.2808, 0.4509, 0.4612, 0.4655, 0.3601, 0.4315])
+ expected_slice = np.array([0.2368, 0.4900, 0.5019, 0.2723, 0.4473, 0.4578, 0.4551, 0.3532, 0.4133])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -251,7 +256,7 @@ def test_stable_diffusion_lcm_custom_timesteps(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3454, 0.5349, 0.5185, 0.2808, 0.4509, 0.4612, 0.4655, 0.3601, 0.4315])
+ expected_slice = np.array([0.2368, 0.4900, 0.5019, 0.2723, 0.4473, 0.4578, 0.4551, 0.3532, 0.4133])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -370,12 +375,6 @@ def test_stable_diffusion_prompt_embeds_with_plain_negative_prompt_list(self):
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
- def test_ip_adapter_single(self):
- expected_pipe_slice = None
- if torch_device == "cpu":
- expected_pipe_slice = np.array([0.3203, 0.4555, 0.4711, 0.3505, 0.3973, 0.4650, 0.5137, 0.3392, 0.4045])
- return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
-
def test_stable_diffusion_ddim_factor_8(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -391,7 +390,7 @@ def test_stable_diffusion_ddim_factor_8(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 136, 136, 3)
- expected_slice = np.array([0.4346, 0.5621, 0.5016, 0.3926, 0.4533, 0.4134, 0.5625, 0.5632, 0.5265])
+ expected_slice = np.array([0.4720, 0.5426, 0.5160, 0.3961, 0.4696, 0.4296, 0.5738, 0.5888, 0.5481])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -409,7 +408,7 @@ def test_stable_diffusion_pndm(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3411, 0.5032, 0.4704, 0.3135, 0.4323, 0.4740, 0.5150, 0.3498, 0.4022])
+ expected_slice = np.array([0.1941, 0.4748, 0.4880, 0.2222, 0.4221, 0.4545, 0.5604, 0.3488, 0.3902])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -449,7 +448,7 @@ def test_stable_diffusion_k_lms(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3149, 0.5246, 0.4796, 0.3218, 0.4469, 0.4729, 0.5151, 0.3597, 0.3954])
+ expected_slice = np.array([0.2681, 0.4785, 0.4857, 0.2426, 0.4473, 0.4481, 0.5610, 0.3676, 0.3855])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -468,7 +467,7 @@ def test_stable_diffusion_k_euler_ancestral(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3151, 0.5243, 0.4794, 0.3217, 0.4468, 0.4728, 0.5152, 0.3598, 0.3954])
+ expected_slice = np.array([0.2682, 0.4782, 0.4855, 0.2424, 0.4472, 0.4479, 0.5612, 0.3676, 0.3854])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -487,7 +486,7 @@ def test_stable_diffusion_k_euler(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3149, 0.5246, 0.4796, 0.3218, 0.4469, 0.4729, 0.5151, 0.3597, 0.3954])
+ expected_slice = np.array([0.2681, 0.4785, 0.4857, 0.2426, 0.4473, 0.4481, 0.5610, 0.3676, 0.3855])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -559,7 +558,7 @@ def test_stable_diffusion_negative_prompt(self):
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.3458, 0.5120, 0.4800, 0.3116, 0.4348, 0.4802, 0.5237, 0.3467, 0.3991])
+ expected_slice = np.array([0.1907, 0.4709, 0.4858, 0.2224, 0.4223, 0.4539, 0.5606, 0.3489, 0.3900])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -639,6 +638,8 @@ def test_attention_slicing_forward_pass(self):
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
+ # MPS currently doesn't support ComplexFloats, which are required for freeU - see https://github.com/huggingface/diffusers/issues/7569.
+ @skip_mps
def test_freeu_enabled(self):
components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
@@ -1439,3 +1440,121 @@ def test_stable_diffusion_euler(self):
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 1e-3
+
+
+# (sayakpaul): This test suite was run on a DGX machine with two GPUs (1, 2).
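+# device_map="balanced" lets accelerate place the pipeline components across the available GPUs as evenly as possible.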
+@slow
+@require_torch_multi_gpu
+@require_accelerate_version_greater("0.27.0")
+class StableDiffusionPipelineDeviceMapTests(unittest.TestCase):
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_inputs(self, generator_device="cpu", seed=0):
+ generator = torch.Generator(device=generator_device).manual_seed(seed)
+ inputs = {
+ "prompt": "a photograph of an astronaut riding a horse",
+ "generator": generator,
+ "num_inference_steps": 50,
+ "guidance_scale": 7.5,
+ "output_type": "np",
+ }
+ return inputs
+
+ def get_pipeline_output_without_device_map(self):
+ sd_pipe = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ).to(torch_device)
+ sd_pipe.set_progress_bar_config(disable=True)
+ inputs = self.get_inputs()
+ no_device_map_image = sd_pipe(**inputs).images
+
+ del sd_pipe
+
+ return no_device_map_image
+
+ def test_forward_pass_balanced_device_map(self):
+ no_device_map_image = self.get_pipeline_output_without_device_map()
+
+ sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
+ )
+ sd_pipe_with_device_map.set_progress_bar_config(disable=True)
+ inputs = self.get_inputs()
+ device_map_image = sd_pipe_with_device_map(**inputs).images
+
+ max_diff = np.abs(device_map_image - no_device_map_image).max()
+ assert max_diff < 1e-3
+
+ def test_components_put_in_right_devices(self):
+ sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
+ )
+
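+        # with a balanced device map, the components should be spread across at least two devices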
+ assert len(set(sd_pipe_with_device_map.hf_device_map.values())) >= 2
+
+ def test_max_memory(self):
+ no_device_map_image = self.get_pipeline_output_without_device_map()
+
+ sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5",
+ device_map="balanced",
+ max_memory={0: "1GB", 1: "1GB"},
+ torch_dtype=torch.float16,
+ )
+ sd_pipe_with_device_map.set_progress_bar_config(disable=True)
+ inputs = self.get_inputs()
+ device_map_image = sd_pipe_with_device_map(**inputs).images
+
+ max_diff = np.abs(device_map_image - no_device_map_image).max()
+ assert max_diff < 1e-3
+
+ def test_reset_device_map(self):
+ sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
+ )
+ sd_pipe_with_device_map.reset_device_map()
+
+ assert sd_pipe_with_device_map.hf_device_map is None
+
+ for name, component in sd_pipe_with_device_map.components.items():
+ if isinstance(component, torch.nn.Module):
+ assert component.device.type == "cpu"
+
+ def test_reset_device_map_to(self):
+ sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
+ )
+ sd_pipe_with_device_map.reset_device_map()
+
+ assert sd_pipe_with_device_map.hf_device_map is None
+
+ # Make sure `to()` can be used and the pipeline can be called.
+ pipe = sd_pipe_with_device_map.to("cuda")
+ _ = pipe("hello", num_inference_steps=2)
+
+ def test_reset_device_map_enable_model_cpu_offload(self):
+ sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
+ )
+ sd_pipe_with_device_map.reset_device_map()
+
+ assert sd_pipe_with_device_map.hf_device_map is None
+
+ # Make sure `enable_model_cpu_offload()` can be used and the pipeline can be called.
+ sd_pipe_with_device_map.enable_model_cpu_offload()
+ _ = sd_pipe_with_device_map("hello", num_inference_steps=2)
+
+ def test_reset_device_map_enable_sequential_cpu_offload(self):
+ sd_pipe_with_device_map = StableDiffusionPipeline.from_pretrained(
+ "runwayml/stable-diffusion-v1-5", device_map="balanced", torch_dtype=torch.float16
+ )
+ sd_pipe_with_device_map.reset_device_map()
+
+ assert sd_pipe_with_device_map.hf_device_map is None
+
+ # Make sure `enable_sequential_cpu_offload()` can be used and the pipeline can be called.
+ sd_pipe_with_device_map.enable_sequential_cpu_offload()
+ _ = sd_pipe_with_device_map("hello", num_inference_steps=2)
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index f0e6818bfc2b..b5dde78f11b2 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -32,6 +32,7 @@
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import IPAdapterMixin
from diffusers.models.attention_processor import AttnProcessor
+from diffusers.models.controlnet_xs import UNetControlNetXSModel
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
from diffusers.models.unets.unet_i2vgen_xl import I2VGenXLUNet
from diffusers.models.unets.unet_motion_model import UNetMotionModel
@@ -39,7 +40,7 @@
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging
from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
-from diffusers.utils.testing_utils import CaptureLogger, require_torch, torch_device
+from diffusers.utils.testing_utils import CaptureLogger, require_torch, skip_mps, torch_device
from ..models.autoencoders.test_models_vae import (
get_asym_autoencoder_kl_config,
@@ -47,7 +48,10 @@
get_autoencoder_tiny_config,
get_consistency_vae_config,
)
-from ..models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
+from ..models.unets.test_models_unet_2d_condition import (
+ create_ip_adapter_faceid_state_dict,
+ create_ip_adapter_state_dict,
+)
from ..others.test_utils import TOKEN, USER, is_staging_test
@@ -125,6 +129,8 @@ def test_vae_tiling(self):
zeros = torch.zeros(shape).to(torch_device)
pipe.vae.decode(zeros)
+ # MPS currently doesn't support ComplexFloats, which are required for freeU - see https://github.com/huggingface/diffusers/issues/7569.
+ @skip_mps
def test_freeu_enabled(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -236,6 +242,14 @@ def test_pipeline_signature(self):
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((2, 1, cross_attention_dim), device=torch_device)
+ def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
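+        # FaceID-style embeds use a 4D layout, one extra dimension compared to `_get_dummy_image_embeds` above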
+ return torch.randn((2, 1, 1, cross_attention_dim), device=torch_device)
+
+ def _get_dummy_masks(self, input_size: int = 64):
+ _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
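+        # mark the left half of the image as the masked region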
+ _masks[0, :, :, : int(input_size / 2)] = 1
+ return _masks
+
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
parameters = inspect.signature(self.pipeline_class.__call__).parameters
if "image" in parameters.keys() and "strength" in parameters.keys():
@@ -363,6 +377,91 @@ def test_ip_adapter_cfg(self, expected_max_diff: float = 1e-4):
assert out_cfg.shape == out_no_cfg.shape
+ def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components).to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
+ sample_size = pipe.unet.config.get("sample_size", 32)
+ block_out_channels = pipe.vae.config.get("block_out_channels", [128, 256, 512, 512])
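+        # the VAE downsamples by 2 for every block after the first, so this recovers the pixel-space size the masks must match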
+ input_size = sample_size * (2 ** (len(block_out_channels) - 1))
+
+ # forward pass without ip adapter
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ output_without_adapter = pipe(**inputs)[0]
+ output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()
+
+ adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
+ pipe.unet._load_ip_adapter_weights(adapter_state_dict)
+
+ # forward pass with single ip adapter and masks, but scale=0 which should have no effect
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+ inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
+ pipe.set_ip_adapter_scale(0.0)
+ output_without_adapter_scale = pipe(**inputs)[0]
+ output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
+
+        # forward pass with single ip adapter and masks, this time with a nonzero adapter scale
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+ inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
+ pipe.set_ip_adapter_scale(42.0)
+ output_with_adapter_scale = pipe(**inputs)[0]
+ output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
+
+ max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
+ max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
+
+ self.assertLess(
+ max_diff_without_adapter_scale,
+ expected_max_diff,
+            "Output without ip-adapter must be the same as normal inference",
+ )
+ self.assertGreater(
+ max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
+ )
+
+ def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components).to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
+
+ # forward pass without ip adapter
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ output_without_adapter = pipe(**inputs)[0]
+ output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()
+
+ adapter_state_dict = create_ip_adapter_faceid_state_dict(pipe.unet)
+ pipe.unet._load_ip_adapter_weights(adapter_state_dict)
+
+ # forward pass with single ip adapter, but scale=0 which should have no effect
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
+ pipe.set_ip_adapter_scale(0.0)
+ output_without_adapter_scale = pipe(**inputs)[0]
+ output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
+
+        # forward pass with single ip adapter, this time with a nonzero adapter scale
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
+ pipe.set_ip_adapter_scale(42.0)
+ output_with_adapter_scale = pipe(**inputs)[0]
+ output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
+
+ max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
+ max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
+
+ self.assertLess(
+ max_diff_without_adapter_scale,
+ expected_max_diff,
+            "Output without ip-adapter must be the same as normal inference",
+ )
+ self.assertGreater(
+ max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
+ )
+
class PipelineLatentTesterMixin:
"""
@@ -1633,7 +1732,10 @@ def test_StableDiffusionMixin_component(self):
self.assertTrue(hasattr(pipe, "vae") and isinstance(pipe.vae, (AutoencoderKL, AutoencoderTiny)))
self.assertTrue(
hasattr(pipe, "unet")
- and isinstance(pipe.unet, (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel))
+ and isinstance(
+ pipe.unet,
+ (UNet2DConditionModel, UNet3DConditionModel, I2VGenXLUNet, UNetMotionModel, UNetControlNetXSModel),
+ )
)
diff --git a/tests/schedulers/test_scheduler_unipc.py b/tests/schedulers/test_scheduler_unipc.py
index fa34ef75b52c..5eb4d5ceef01 100644
--- a/tests/schedulers/test_scheduler_unipc.py
+++ b/tests/schedulers/test_scheduler_unipc.py
@@ -180,6 +180,10 @@ def test_prediction_type(self):
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)
+ def test_rescale_betas_zero_snr(self):
+ for rescale_betas_zero_snr in [True, False]:
+ self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr)
+
def test_solver_order_and_type(self):
for solver_type in ["bh1", "bh2"]:
for order in [1, 2, 3]:
@@ -229,20 +233,29 @@ def test_full_loop_with_karras_and_v_prediction(self):
assert abs(result_mean.item() - 0.1966) < 1e-3
def test_fp16_support(self):
- scheduler_class = self.scheduler_classes[0]
- scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
- scheduler = scheduler_class(**scheduler_config)
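+        # exercise the fp16 path across every combination of solver order, solver type, and prediction type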
+ for order in [1, 2, 3]:
+ for solver_type in ["bh1", "bh2"]:
+ for prediction_type in ["epsilon", "sample", "v_prediction"]:
+ scheduler_class = self.scheduler_classes[0]
+ scheduler_config = self.get_scheduler_config(
+ thresholding=True,
+ dynamic_thresholding_ratio=0,
+ prediction_type=prediction_type,
+ solver_order=order,
+ solver_type=solver_type,
+ )
+ scheduler = scheduler_class(**scheduler_config)
- num_inference_steps = 10
- model = self.dummy_model()
- sample = self.dummy_sample_deter.half()
- scheduler.set_timesteps(num_inference_steps)
+ num_inference_steps = 10
+ model = self.dummy_model()
+ sample = self.dummy_sample_deter.half()
+ scheduler.set_timesteps(num_inference_steps)
- for i, t in enumerate(scheduler.timesteps):
- residual = model(sample, t)
- sample = scheduler.step(residual, t, sample).prev_sample
+ for i, t in enumerate(scheduler.timesteps):
+ residual = model(sample, t)
+ sample = scheduler.step(residual, t, sample).prev_sample
- assert sample.dtype == torch.float16
+ assert sample.dtype == torch.float16
def test_full_loop_with_noise(self):
scheduler_class = self.scheduler_classes[0]
diff --git a/utils/update_metadata.py b/utils/update_metadata.py
new file mode 100644
index 000000000000..840e4be78423
--- /dev/null
+++ b/utils/update_metadata.py
@@ -0,0 +1,106 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Utility that updates the metadata of the Diffusers library in the repository `huggingface/diffusers-metadata`.
+
+Usage for an update (as used by the GitHub action `update_metadata`):
+
+```bash
+python utils/update_metadata.py
+```
+
+Script modified from:
+https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py
+"""
+import argparse
+import os
+import tempfile
+
+import pandas as pd
+from datasets import Dataset
+from huggingface_hub import upload_folder
+
+from diffusers.pipelines.auto_pipeline import (
+ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
+ AUTO_INPAINT_PIPELINES_MAPPING,
+ AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
+)
+
+
+def get_supported_pipeline_table() -> dict:
+ """
+ Generates a dictionary containing the supported auto classes for each pipeline type,
+ using the content of the auto modules.
+ """
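+    # The returned layout looks like (illustrative; the actual class names depend on the auto mappings):
+    #   {"pipeline_class": [...], "pipeline_tag": ["text-to-image", ...], "auto_class": ["AutoPipelineForText2Image", ...]}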
+ # All supported pipelines for automatic mapping.
+ all_supported_pipeline_classes = [
+ (class_name.__name__, "text-to-image", "AutoPipelineForText2Image")
+ for _, class_name in AUTO_TEXT2IMAGE_PIPELINES_MAPPING.items()
+ ]
+ all_supported_pipeline_classes += [
+ (class_name.__name__, "image-to-image", "AutoPipelineForImage2Image")
+ for _, class_name in AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.items()
+ ]
+ all_supported_pipeline_classes += [
+ (class_name.__name__, "image-to-image", "AutoPipelineForInpainting")
+ for _, class_name in AUTO_INPAINT_PIPELINES_MAPPING.items()
+ ]
+    all_supported_pipeline_classes = list(set(all_supported_pipeline_classes))
+    all_supported_pipeline_classes.sort(key=lambda x: x[0])
+
+ data = {}
+ data["pipeline_class"] = [sample[0] for sample in all_supported_pipeline_classes]
+ data["pipeline_tag"] = [sample[1] for sample in all_supported_pipeline_classes]
+ data["auto_class"] = [sample[2] for sample in all_supported_pipeline_classes]
+
+ return data
+
+
+def update_metadata(commit_sha: str):
+ """
+ Update the metadata for the Diffusers repo in `huggingface/diffusers-metadata`.
+
+ Args:
+ commit_sha (`str`): The commit SHA on Diffusers corresponding to this update.
+ """
+ pipelines_table = get_supported_pipeline_table()
+ pipelines_table = pd.DataFrame(pipelines_table)
+ pipelines_dataset = Dataset.from_pandas(pipelines_table)
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ pipelines_dataset.to_json(os.path.join(tmp_dir, "pipeline_tags.json"))
+
+ if commit_sha is not None:
+ commit_message = (
+ f"Update with commit {commit_sha}\n\nSee: "
+ f"https://github.com/huggingface/diffusers/commit/{commit_sha}"
+ )
+ else:
+ commit_message = "Update"
+
+ upload_folder(
+ repo_id="huggingface/diffusers-metadata",
+ folder_path=tmp_dir,
+ repo_type="dataset",
+ commit_message=commit_message,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--commit_sha", default=None, type=str, help="The sha of the commit going with this update.")
+ args = parser.parse_args()
+
+ update_metadata(args.commit_sha)