diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 00c9cbed0636..d638d1435dfc 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -22,7 +22,7 @@ jobs:
runs-on: [single-gpu, nvidia-gpu, a10, ci]
container:
image: diffusers/diffusers-pytorch-compile-cuda
- options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --gpus 0
+ options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py
index 96d30c0837c8..c9932cc71c38 100644
--- a/benchmarks/run_all.py
+++ b/benchmarks/run_all.py
@@ -39,11 +39,8 @@ def main():
for file in python_files:
print(f"****** Running file: {file} ******")
- if "ip_adapters" in file:
- continue
-
# Run with canonical settings.
- if file != "benchmark_text_to_image.py":
+ if file != "benchmark_text_to_image.py" and file != "benchmark_ip_adapters.py":
command = f"python {file}"
run_command(command.split())
@@ -52,7 +49,8 @@ def main():
# Run variants.
for file in python_files:
- if "ip_adapters" in file:
+ # See: https://github.com/pytorch/pytorch/issues/129637
+ if file == "benchmark_ip_adapters.py":
continue
if file == "benchmark_text_to_image.py":
diff --git a/docs/source/en/api/models/autoencoderkl.md b/docs/source/en/api/models/autoencoderkl.md
index 158829a35b00..dd881089ad00 100644
--- a/docs/source/en/api/models/autoencoderkl.md
+++ b/docs/source/en/api/models/autoencoderkl.md
@@ -21,7 +21,7 @@ The abstract from the paper is:
## Loading from the original format
By default the [`AutoencoderKL`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
-from the original format using [`FromOriginalVAEMixin.from_single_file`] as follows:
+from the original format using [`FromOriginalModelMixin.from_single_file`] as follows:
```py
from diffusers import AutoencoderKL
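# A hedged sketch of the call described above (not part of this diff), assuming the
# single-file VAE checkpoint below; a local *.safetensors/*.ckpt path works the same way.
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
model = AutoencoderKL.from_single_file(url)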
diff --git a/docs/source/en/api/models/controlnet.md b/docs/source/en/api/models/controlnet.md
index b57620e1e414..c2fdf1c6f975 100644
--- a/docs/source/en/api/models/controlnet.md
+++ b/docs/source/en/api/models/controlnet.md
@@ -21,7 +21,7 @@ The abstract from the paper is:
## Loading from the original format
By default the [`ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
-from the original format using [`FromOriginalControlnetMixin.from_single_file`] as follows:
+from the original format using [`FromOriginalModelMixin.from_single_file`] as follows:
```py
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
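# A hedged sketch of the renamed mixin in use (not part of this diff), assuming the
# checkpoints below; both URLs can also be local paths.
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
controlnet = ControlNetModel.from_single_file(url)
pipe = StableDiffusionControlNetPipeline.from_single_file(
    "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors",
    controlnet=controlnet,
)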
diff --git a/docs/source/en/api/pipelines/hunyuandit.md b/docs/source/en/api/pipelines/hunyuandit.md
index 9ac5d90fedbf..250533837ed0 100644
--- a/docs/source/en/api/pipelines/hunyuandit.md
+++ b/docs/source/en/api/pipelines/hunyuandit.md
@@ -34,6 +34,12 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
+<Tip>
+
+You can further improve generation quality by passing the generated image from [`HunyuanDiTPipeline`] to the [SDXL refiner](../../using-diffusers/sdxl#base-to-refiner-model) model.
+
+</Tip>
+
## Optimization
You can optimize the pipeline's runtime and memory consumption with torch.compile and feed-forward chunking. To learn about other optimization methods, check out the [Speed up inference](../../optimization/fp16) and [Reduce memory usage](../../optimization/memory) guides.
diff --git a/docs/source/en/api/pipelines/pixart_sigma.md b/docs/source/en/api/pipelines/pixart_sigma.md
index 2bf69f1ecc6d..592ba0f374be 100644
--- a/docs/source/en/api/pipelines/pixart_sigma.md
+++ b/docs/source/en/api/pipelines/pixart_sigma.md
@@ -37,6 +37,12 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
+<Tip>
+
+You can further improve generation quality by passing the generated image from [`PixArtSigmaPipeline`] to the [SDXL refiner](../../using-diffusers/sdxl#base-to-refiner-model) model.
+
+</Tip>
+
## Inference with under 8GB GPU VRAM
Run the [`PixArtSigmaPipeline`] with under 8GB GPU VRAM by loading the text encoder in 8-bit precision. Let's walk through a full-fledged example.
diff --git a/docs/source/en/using-diffusers/other-formats.md b/docs/source/en/using-diffusers/other-formats.md
index a7f85cad96b0..6acd736b5f34 100644
--- a/docs/source/en/using-diffusers/other-formats.md
+++ b/docs/source/en/using-diffusers/other-formats.md
@@ -418,7 +418,7 @@ my_local_checkpoint_path = hf_hub_download(
my_local_config_path = snapshot_download(
repo_id="segmind/SSD-1B",
- allowed_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
+ allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
)
pipeline = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
@@ -438,7 +438,7 @@ my_local_checkpoint_path = hf_hub_download(
my_local_config_path = snapshot_download(
repo_id="segmind/SSD-1B",
- allowed_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
+ allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"],
local_dir="my_local_config"
)
@@ -468,7 +468,7 @@ print("My local checkpoint: ", my_local_checkpoint_path)
my_local_config_path = snapshot_download(
repo_id="segmind/SSD-1B",
- allowed_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
+ allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"],
local_dir_use_symlinks=False,
)
print("My local config: ", my_local_config_path)
diff --git a/docs/source/en/using-diffusers/sdxl.md b/docs/source/en/using-diffusers/sdxl.md
index 6b9ab7f475e6..9938d561052b 100644
--- a/docs/source/en/using-diffusers/sdxl.md
+++ b/docs/source/en/using-diffusers/sdxl.md
@@ -285,6 +285,12 @@ refiner = DiffusionPipeline.from_pretrained(
).to("cuda")
```
+<Tip>
+
+You can use the SDXL refiner with a different base model. For example, you can use the [Hunyuan-DiT](../../api/pipelines/hunyuandit) or [PixArt-Sigma](../../api/pipelines/pixart_sigma) pipelines to generate images with better prompt adherence. Once you have generated an image, you can pass it to the SDXL refiner model to enhance the final generation quality.
+
+</Tip>
+
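A hedged sketch of that cross-model flow (not part of this diff), assuming the `PixArt-alpha/PixArt-Sigma-XL-2-1024-MS` base and the `stabilityai/stable-diffusion-xl-refiner-1.0` refiner; the decoded base image is refined with an image-to-image pass:

```py
import torch
from diffusers import PixArtSigmaPipeline, StableDiffusionXLImg2ImgPipeline

base = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
).to("cuda")
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
).to("cuda")

prompt = "an astronaut riding a green horse, highly detailed"
image = base(prompt=prompt).images[0]
# a low strength keeps the composition and lets the refiner only polish details
refined = refiner(prompt=prompt, image=image, strength=0.3).images[0]
```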
Generate an image from the base model, and set the model output to **latent** space:
```py
diff --git a/docs/source/ko/conceptual/philosophy.md b/docs/source/ko/conceptual/philosophy.md
index 8b18df713642..5d49c075a165 100644
--- a/docs/source/ko/conceptual/philosophy.md
+++ b/docs/source/ko/conceptual/philosophy.md
@@ -10,30 +10,30 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-# Philosophy
+# Philosophy [[philosophy]]
🧨 Diffusers provides **state-of-the-art** pretrained diffusion models across multiple modalities.
Its purpose is to serve as a **modular toolbox** for both inference and training.
-We aim to build a library that stands the test of time, and therefore take API design very seriously.
+We aim to build a library that does not deteriorate over time, which is why we take API design very seriously.
-In a nutshell, Diffusers is built to be a natural extension of PyTorch. Therefore, most of our design choices are based on [PyTorch's design principles](https://pytorch.org/docs/stable/community/design.html#pytorch-design-philosophy). Let's go over the most important ones:
+In a nutshell, Diffusers is built so that it naturally extends PyTorch. Therefore, most of our design choices are based on [PyTorch's design principles](https://pytorch.org/docs/stable/community/design.html#pytorch-design-philosophy). Let's go over the most important ones:
-## Usability over performance
+## Usability over performance [[usability-over-performance]]
-- While Diffusers has many built-in performance-enhancing features (see [Memory and Speed](https://huggingface.co/docs/diffusers/optimization/fp16) for details), models are always loaded with the highest precision and minimal optimization. Therefore, diffusion pipelines are instantiated on the CPU in float32 precision by default unless defined otherwise. This ensures usability across different platforms and accelerators and means that no complex installation is required to run the library.
+- While Diffusers has a variety of performance-enhancing features built in (see [Memory and Speed](https://huggingface.co/docs/diffusers/optimization/fp16) for details), models are always loaded with the highest precision and minimal optimization. Therefore, diffusion pipelines are always instantiated on the CPU in float32 precision by default unless the user specifies otherwise. This ensures usability across different platforms and accelerators and means that no complex installation is required to run the library.
- Diffusers aims to be a **lightweight** package and therefore has very few required dependencies, but it has many soft dependencies that can improve performance (such as `accelerate`, `safetensors`, `onnx`, etc.). We strive to keep the library as lightweight as possible so that it can be added as a dependency to other packages without much concern.
- Diffusers prefers concise, easy-to-understand code. This means that condensed code syntax such as lambda functions and advanced PyTorch operators is rarely used.
-## Simple over easy
+## Simple over easy [[simple-over-easy]]
PyTorch states that **explicit is better than implicit** and **simple is better than complex**. This design philosophy is reflected in multiple parts of the library:
-- We follow PyTorch's API with methods like [`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to) to let the user handle device management.
+- We follow PyTorch's API with methods like [`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to) to let the user handle device management.
- Raising concise error messages is preferred to silently correcting erroneous input. Diffusers aims to teach its users rather than to make the library as easy to use as possible.
- Complex model and scheduler logic is exposed instead of being handled magically inside. Schedulers/samplers are separated from diffusion models with minimal dependencies on each other. This forces the user to write the unrolled denoising loop. However, the separation makes debugging easier and gives the user more control over adjusting the denoising process or swapping out diffusion models or schedulers.
- Separately trained components of the diffusion pipeline, i.e. the text encoder, the unet, and the variational autoencoder, each have their own model class. This forces the user to handle the interaction between the different model components, and the serialization format separates the model components into different files. However, it makes debugging and customization easier. DreamBooth or Textual Inversion training is very simple thanks to Diffusers' ability to separate single components of the diffusion pipeline.
-## Tweakable, contributor-friendly over abstraction
+## Tweakable, contributor-friendly over abstraction [[tweakable-contributor-friendly-over-abstraction]]
For large parts of the library, Diffusers adopts an important design principle of the [Transformers library](https://github.com/huggingface/transformers): prefer copy-pasted code over hasty abstractions. This design principle is quite opinionated and stands in stark contrast to popular design principles such as [Don't repeat yourself (DRY)](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself).
In short, just as Transformers does for modeling files, Diffusers prefers to keep a very low level of abstraction and very self-contained code. Functions, long code blocks, and even classes can be copied across multiple files, which at first can look like a bad, sloppy design choice that makes the library hard to maintain. However, this design has proven to be very successful and is a great fit for community-driven, open-source machine learning libraries, for the following reasons:
@@ -48,11 +48,11 @@ In Diffusers, this philosophy is applied to both pipelines and schedulers
Great, now you should roughly understand why 🧨 Diffusers is designed the way it is 🤗.
We try to apply these design principles consistently across the library. Nevertheless, there are some minor exceptions to the philosophy and some unfortunate design choices. If you have feedback on the design, we would be happy to hear it [directly on GitHub](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=).
-## Design philosophy in details
+## Design philosophy in details [[design-philosophy-in-details]]
Now, let's take a closer look at the details of the design philosophy. Diffusers essentially consists of three major classes: [pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines), [models](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models), and [schedulers](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers). Let's walk through the more detailed design decisions for each class.
-### Pipelines
+### Pipelines [[pipelines]]
Pipelines are designed to be easy to use (and therefore do not follow [*Simple over easy*](#simple-over-easy) 100%), are not feature complete, and should loosely be seen as examples of how to use [models](#models) and [schedulers](#schedulers) for inference.
@@ -65,11 +65,11 @@ In Diffusers, this philosophy is applied to both pipelines and schedulers
- Pipelines should be highly readable, easy to understand, and easy to tweak.
- Pipelines should be designed to build off of one another and be easy to integrate into higher-level APIs.
- Pipelines are not intended to be feature-complete user interfaces. For feature-complete user interfaces, you should rather look at [InvokeAI](https://github.com/invoke-ai/InvokeAI), [Diffuzers](https://github.com/abhishekkrthakur/diffuzers), and [lama-cleaner](https://github.com/Sanster/lama-cleaner).
-- Every pipeline should be runnable solely via its `__call__` method. The naming of the `__call__` arguments should be shared across all pipelines.
+- Every pipeline should be runnable solely via its `__call__` method. The naming of the `__call__` arguments should be shared across all pipelines.
- Pipelines should be named after the task they are intended to solve.
- In almost all cases, a new diffusion pipeline should be implemented in a new pipeline folder/file.
-### Models
+### Models [[models]]
Models are designed as configurable toolboxes that are natural extensions of [PyTorch's Module class](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). They only partly follow the **single-file policy**.
@@ -85,7 +85,7 @@ In Diffusers, this philosophy is applied to both pipelines and schedulers
- Models should be designed to be easily extendable to future changes. This can be achieved by limiting public function arguments and configuration arguments and by "foreseeing" future changes. For example, it is usually better to add a string `"...type"` argument that can easily be extended to new future types than a boolean `is_..._type` argument. Only the minimum amount of changes should be made to existing architectures to make a new model checkpoint work.
- The model design is a difficult trade-off between keeping the code readable and concise and supporting many model checkpoints. For most of the modeling code, classes should be adapted for new model checkpoints, but there are exceptions where adding new classes is preferred to keep the code concise and readable in the long term, such as the [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-### Schedulers
+### Schedulers [[schedulers]]
Schedulers are responsible for guiding the denoising process for inference and for defining the noise schedule for training. They are designed as individual classes with loadable configuration files and strictly follow the **single-file policy**.
@@ -95,7 +95,7 @@ In Diffusers, this philosophy is applied to both pipelines and schedulers
- One scheduler Python file corresponds to one scheduler algorithm (as might be defined in a paper).
- If schedulers share similar functionality, the `# Copied from` mechanism can be used.
- All schedulers inherit from `SchedulerMixin` and `ConfigMixin`.
-- Schedulers can easily be swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method, as explained in detail [here](../using-diffusers/schedulers.md).
+- Schedulers can easily be swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method, as explained in detail [here](../using-diffusers/schedulers.md).
- Every scheduler has to provide a `set_num_inference_steps` and a `step` function. `set_num_inference_steps(...)` has to be called before each denoising process, i.e. before `step(...)` is called.
- Every scheduler exposes the timesteps to be "looped over" via a `timesteps` attribute, which is an array of the timesteps at which the model will be called.
- The `step(...)` function takes the predicted model output and the "current" sample (x_t) as input and returns the "previous", slightly more denoised sample (x_t-1).
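A minimal sketch of the unrolled denoising loop this separation requires (not part of this diff), assuming the `google/ddpm-cat-256` checkpoint; note that the method current schedulers expose for setting the number of inference steps is `set_timesteps`:

```py
import torch
from diffusers import DDPMScheduler, UNet2DModel

model = UNet2DModel.from_pretrained("google/ddpm-cat-256")
scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")
scheduler.set_timesteps(50)

sample = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
for t in scheduler.timesteps:
    with torch.no_grad():
        noise_pred = model(sample, t).sample  # predicted noise residual
    # step(...) takes the model output and the current sample x_t and returns x_{t-1}
    sample = scheduler.step(noise_pred, t, sample).prev_sample
```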
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
index 4d442b62332f..cf558f082018 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
@@ -1290,6 +1290,7 @@ def save_model_hook(models, weights, output_dir):
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
get_peft_model_state_dict(model)
)
+ else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
@@ -1524,17 +1525,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
torch.cuda.empty_cache()
# Scheduler and math around the number of training steps.
- overrode_max_train_steps = False
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- overrode_max_train_steps = True
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
- num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1551,8 +1557,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if overrode_max_train_steps:
+ if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
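# Worked example of the scheduler math introduced above, with hypothetical numbers (not from
# the PR): an unsharded dataloader of 100 batches, 2 processes, gradient_accumulation_steps=4,
# and num_train_epochs=3.
import math
len_train_dataloader_after_sharding = math.ceil(100 / 2)                          # 50
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / 4)   # 13
num_training_steps_for_scheduler = 3 * num_update_steps_per_epoch * 2             # 78, passed to get_scheduler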
@@ -1845,10 +1857,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}
- if torch.backends.mps.is_available():
- autocast_ctx = nullcontext()
- else:
- autocast_ctx = torch.autocast(accelerator.device.type)
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
with autocast_ctx:
images = [
@@ -1869,7 +1881,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
]
}
)
-
del pipeline
torch.cuda.empty_cache()
@@ -1971,7 +1982,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
- save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
+ save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
save_model_card(
model_id if not args.push_to_hub else repo_id,
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index d54d9f1b2402..9d06ce6cba16 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -573,6 +573,13 @@ def parse_args(input_args=None):
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
+ parser.add_argument(
+ "--clip_skip",
+ type=int,
+ default=None,
+ help="Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that "
+ "the output of the pre-final layer will be used for computing the prompt embeddings.",
+ )
parser.add_argument(
"--text_encoder_lr",
@@ -1236,7 +1243,7 @@ def tokenize_prompt(tokenizer, prompt, add_special_tokens=False):
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
-def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None, clip_skip=None):
prompt_embeds_list = []
for i, text_encoder in enumerate(text_encoders):
@@ -1253,7 +1260,11 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
- prompt_embeds = prompt_embeds[-1][-2]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds[-1][-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 2)]
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
prompt_embeds_list.append(prompt_embeds)
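# Hedged illustration of the clip_skip indexing above (not part of the diff), using a plain
# CLIP text encoder: hidden_states[-2] is the SDXL default (penultimate layer), and
# clip_skip=1 selects hidden_states[-3].
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
input_ids = tokenizer("a photo of sks dog", return_tensors="pt").input_ids
with torch.no_grad():
    out = text_encoder(input_ids, output_hidden_states=True)
clip_skip = 1
default_embeds = out.hidden_states[-2]                 # clip_skip is None
skipped_embeds = out.hidden_states[-(clip_skip + 2)]   # clip_skip = 1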
@@ -1830,9 +1841,9 @@ def compute_time_ids(crops_coords_top_left, original_size=None):
tokenizers = [tokenizer_one, tokenizer_two]
text_encoders = [text_encoder_one, text_encoder_two]
- def compute_text_embeddings(prompt, text_encoders, tokenizers):
+ def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
with torch.no_grad():
- prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt, clip_skip=clip_skip)
prompt_embeds = prompt_embeds.to(accelerator.device)
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
return prompt_embeds, pooled_prompt_embeds
@@ -1842,7 +1853,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
# the redundant encoding.
if freeze_text_encoder and not train_dataset.custom_instance_prompts:
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
- args.instance_prompt, text_encoders, tokenizers
+ args.instance_prompt, text_encoders, tokenizers, args.clip_skip
)
# Handle class prompt for prior-preservation.
@@ -1899,17 +1910,22 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
torch.cuda.empty_cache()
# Scheduler and math around the number of training steps.
- overrode_max_train_steps = False
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- overrode_max_train_steps = True
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
- num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1926,8 +1942,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if overrode_max_train_steps:
+ if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -2041,7 +2063,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
if train_dataset.custom_instance_prompts:
if freeze_text_encoder:
prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
- prompts, text_encoders, tokenizers
+ prompts, text_encoders, tokenizers, args.clip_skip
)
else:
@@ -2136,6 +2158,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
tokenizers=None,
prompt=None,
text_input_ids_list=[tokens_one, tokens_two],
+ clip_skip=args.clip_skip,
)
unet_added_conditions.update(
{"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
@@ -2402,7 +2425,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
- save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
+ save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
save_model_card(
model_id if not args.push_to_hub else repo_id,
diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py
index c05a70507af3..dee79bf9f190 100644
--- a/examples/controlnet/train_controlnet_sdxl.py
+++ b/examples/controlnet/train_controlnet_sdxl.py
@@ -1088,17 +1088,22 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
)
# Scheduler and math around the number of training steps.
- overrode_max_train_steps = False
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- overrode_max_train_steps = True
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
- num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1110,8 +1115,14 @@ def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizer
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if overrode_max_train_steps:
+ if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
diff --git a/examples/dreambooth/test_dreambooth_lora_sd3.py b/examples/dreambooth/test_dreambooth_lora_sd3.py
new file mode 100644
index 000000000000..518738b78246
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_lora_sd3.py
@@ -0,0 +1,165 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRASD3(ExamplesTestsAccelerate):
+ instance_data_dir = "docs/source/en/imgs"
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
+ script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py"
+
+ def test_dreambooth_lora_sd3(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_text_encoder_sd3(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --train_text_encoder
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ starts_with_expected_prefix = all(
+ (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
+ )
+ self.assertTrue(starts_with_expected_prefix)
+
+ def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
+
+ resume_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
diff --git a/examples/dreambooth/test_dreambooth_sd3.py b/examples/dreambooth/test_dreambooth_sd3.py
new file mode 100644
index 000000000000..19fb7243cefd
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_sd3.py
@@ -0,0 +1,203 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import shutil
+import sys
+import tempfile
+
+from diffusers import DiffusionPipeline, SD3Transformer2DModel
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothSD3(ExamplesTestsAccelerate):
+ instance_data_dir = "docs/source/en/imgs"
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
+ script_path = "examples/dreambooth/train_dreambooth_sd3.py"
+
+ def test_dreambooth(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors")))
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
+
+ def test_dreambooth_checkpointing(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ # Run training script with checkpointing
+ # max_train_steps == 4, checkpointing_steps == 2
+ # Should create checkpoints at steps 2, 4
+
+ initial_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 4
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + initial_run_args)
+
+ # check can run the original fully trained output pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir)
+ pipe(self.instance_prompt, num_inference_steps=1)
+
+ # check checkpoint directories exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+
+ # check can run an intermediate checkpoint
+ transformer = SD3Transformer2DModel.from_pretrained(tmpdir, subfolder="checkpoint-2/transformer")
+ pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)
+ pipe(self.instance_prompt, num_inference_steps=1)
+
+ # Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
+ shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
+
+ # Run training script for 7 total steps resuming from checkpoint 4
+
+ resume_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 6
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --seed=0
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ # check can run new fully trained pipeline
+ pipe = DiffusionPipeline.from_pretrained(tmpdir)
+ pipe(self.instance_prompt, num_inference_steps=1)
+
+ # check old checkpoints do not exist
+ self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
+
+ # check new checkpoints exist
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
+ self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-2", "checkpoint-4"},
+ )
+
+ resume_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 2c66c341f78f..5401ee570a34 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -101,19 +101,37 @@ def save_model_card(
## Model description
-These are {repo_id} DreamBooth weights for {base_model}.
+These are {repo_id} DreamBooth LoRA weights for {base_model}.
-The weights were trained using [DreamBooth](https://dreambooth.github.io/).
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md).
-LoRA for the text encoder was enabled: {train_text_encoder}.
+Was LoRA for the text encoder enabled? {train_text_encoder}.
## Trigger words
-You should use {instance_prompt} to trigger the image generation.
+You should use `{instance_prompt}` to trigger the image generation.
## Download model
-[Download]({repo_id}/tree/main) them in the Files & versions tab.
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
+
+- **LoRA**: download **[`diffusers_lora_weights.safetensors` here 💾](/{repo_id}/blob/main/diffusers_lora_weights.safetensors)**.
+ - Rename it and place it in your `models/Lora` folder.
+ - On AUTOMATIC1111, load the LoRA by adding `` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
## License
@@ -962,7 +980,7 @@ def encode_prompt(
prompt=prompt,
device=device if device is not None else text_encoder.device,
num_images_per_prompt=num_images_per_prompt,
- text_input_ids=text_input_ids_list[i],
+ text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
)
clip_prompt_embeds_list.append(prompt_embeds)
clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
@@ -976,7 +994,7 @@ def encode_prompt(
max_sequence_length,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
- text_input_ids=text_input_ids_list[:-1],
+ text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
device=device if device is not None else text_encoders[-1].device,
)
@@ -1491,6 +1509,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
) = accelerator.prepare(
transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
)
+ assert text_encoder_one is not None
+ assert text_encoder_two is not None
+ assert text_encoder_three is not None
else:
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, optimizer, train_dataloader, lr_scheduler
@@ -1598,7 +1619,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
tokens_three = tokenize_prompt(tokenizer_three, prompts)
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
- tokenizers=[None, None, tokenizer_three],
+ tokenizers=[None, None, None],
prompt=prompts,
max_sequence_length=args.max_sequence_length,
text_input_ids_list=[tokens_one, tokens_two, tokens_three],
@@ -1608,7 +1629,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
tokenizers=[None, None, tokenizer_three],
- prompt=prompts,
+ prompt=args.instance_prompt,
max_sequence_length=args.max_sequence_length,
text_input_ids_list=[tokens_one, tokens_two, tokens_three],
)
@@ -1685,10 +1706,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
accelerator.backward(loss)
if accelerator.sync_gradients:
- params_to_clip = itertools.chain(
- transformer_lora_parameters,
- text_lora_parameters_one,
- text_lora_parameters_two if args.train_text_encoder else transformer_lora_parameters,
+ params_to_clip = (
+ itertools.chain(
+ transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two
+ )
+ if args.train_text_encoder
+ else transformer_lora_parameters
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@@ -1741,13 +1764,6 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
)
- else:
- text_encoder_three = text_encoder_cls_three.from_pretrained(
- args.pretrained_model_name_or_path,
- subfolder="text_encoder_3",
- revision=args.revision,
- variant=args.variant,
- )
pipeline = StableDiffusion3Pipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
@@ -1767,7 +1783,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
pipeline_args=pipeline_args,
epoch=epoch,
)
- del text_encoder_one, text_encoder_two, text_encoder_three
+ if not args.train_text_encoder:
+ del text_encoder_one, text_encoder_two, text_encoder_three
+
torch.cuda.empty_cache()
gc.collect()
diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py
index c8f2fb1ac61b..9a72294c20bd 100644
--- a/examples/dreambooth/train_dreambooth_sd3.py
+++ b/examples/dreambooth/train_dreambooth_sd3.py
@@ -95,17 +95,22 @@ def save_model_card(
These are {repo_id} DreamBooth weights for {base_model}.
-The weights were trained using [DreamBooth](https://dreambooth.github.io/).
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md).
-Text encoder was fine-tuned: {train_text_encoder}.
+Was the text encoder fine-tuned? {train_text_encoder}.
## Trigger words
-You should use {instance_prompt} to trigger the image generation.
+You should use `{instance_prompt}` to trigger the image generation.
-## Download model
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
-[Download]({repo_id}/tree/main) them in the Files & versions tab.
+```py
+from diffusers import AutoPipelineForText2Image
+import torch
+pipeline = AutoPipelineForText2Image.from_pretrained('{repo_id}', torch_dtype=torch.float16).to('cuda')
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
## License
diff --git a/examples/research_projects/sd3_lora_colab/README.md b/examples/research_projects/sd3_lora_colab/README.md
new file mode 100644
index 000000000000..d90a1c9f0ae2
--- /dev/null
+++ b/examples/research_projects/sd3_lora_colab/README.md
@@ -0,0 +1,38 @@
+# Running Stable Diffusion 3 DreamBooth LoRA training under 16GB
+
+This is an **EDUCATIONAL** project that provides utilities for DreamBooth LoRA training for [Stable Diffusion 3 (SD3)](https://huggingface.co/papers/2403.03206) under 16GB GPU VRAM. This means you can successfully try out this project using a [free-tier Colab Notebook](https://colab.research.google.com/github/huggingface/diffusers/blob/main/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb) instance. 🤗
+
+> [!NOTE]
+> SD3 is gated, so you need to make sure you agree to [share your contact info](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) to access the model before using it with Diffusers. Once you have access, you need to log in so your system knows youโre authorized. Use the command below to log in:
+
+```bash
+huggingface-cli login
+```
+
+This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
+
+For setup, inference code, and details on how to run the code, please follow the Colab Notebook provided above.
+
+## How
+
+We make use of several techniques to make this possible:
+
+* Compute the embeddings from the instance prompt and serialize them for later reuse. This is implemented in the [`compute_embeddings.py`](./compute_embeddings.py) script. We use an 8bit (as introduced in [`LLM.int8()`](https://arxiv.org/abs/2208.07339)) T5 to reduce memory requirements to ~10.5GB.
+* In the `train_dreambooth_sd3_lora_miniature.py` script, we make use of:
+ * 8bit Adam for optimization through the `bitsandbytes` library.
+ * Gradient checkpointing and gradient accumulation.
+ * FP16 precision.
+ * Flash attention through `F.scaled_dot_product_attention()`.
+
+Computing the text embeddings is arguably the most memory-intensive part in the pipeline as SD3 employs three text encoders. If we run them in FP32, it will take about 20GB of VRAM. With FP16, we are down to 12GB.
+
+
+## Gotchas
+
+This project is educational. It exists to showcase the possibility of fine-tuning a big diffusion system on consumer GPUs. But additional components might have to be added to obtain state-of-the-art performance. Below are some commonly known gotchas that users should be aware of:
+
+* Training of text encoders is purposefully disabled.
+* Techniques such as prior-preservation are unsupported.
+* Custom instance captions for instance images are unsupported, but this should be relatively easy to integrate.
+
+Hopefully, this project gives you a template to extend it further to suit your needs.
diff --git a/examples/research_projects/sd3_lora_colab/compute_embeddings.py b/examples/research_projects/sd3_lora_colab/compute_embeddings.py
new file mode 100644
index 000000000000..5014752ffe34
--- /dev/null
+++ b/examples/research_projects/sd3_lora_colab/compute_embeddings.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import glob
+import hashlib
+
+import pandas as pd
+import torch
+from transformers import T5EncoderModel
+
+from diffusers import StableDiffusion3Pipeline
+
+
+PROMPT = "a photo of sks dog"
+MAX_SEQ_LENGTH = 77
+LOCAL_DATA_DIR = "dog"
+OUTPUT_PATH = "sample_embeddings.parquet"
+
+
+def bytes_to_giga_bytes(bytes):
+ return bytes / 1024 / 1024 / 1024
+
+
+def generate_image_hash(image_path):
+ with open(image_path, "rb") as f:
+ img_data = f.read()
+ return hashlib.sha256(img_data).hexdigest()
+
+
+def load_sd3_pipeline():
+ id = "stabilityai/stable-diffusion-3-medium-diffusers"
+ text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_3", load_in_8bit=True, device_map="auto")
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ id, text_encoder_3=text_encoder, transformer=None, vae=None, device_map="balanced"
+ )
+ return pipeline
+
+
+@torch.no_grad()
+def compute_embeddings(pipeline, prompt, max_sequence_length):
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None, max_sequence_length=max_sequence_length)
+
+ print(
+ f"{prompt_embeds.shape=}, {negative_prompt_embeds.shape=}, {pooled_prompt_embeds.shape=}, {negative_pooled_prompt_embeds.shape}"
+ )
+
+ max_memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
+ print(f"Max memory allocated: {max_memory:.3f} GB")
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+
+def run(args):
+ pipeline = load_sd3_pipeline()
+ prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = compute_embeddings(
+ pipeline, args.prompt, args.max_sequence_length
+ )
+
+ # Assumes that the images within `args.local_image_dir` have a JPEG extension. Change
+ # as needed.
+ image_paths = glob.glob(f"{args.local_data_dir}/*.jpeg")
+ data = []
+ for image_path in image_paths:
+ img_hash = generate_image_hash(image_path)
+ data.append(
+ (img_hash, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)
+ )
+
+ # Create a DataFrame
+ embedding_cols = [
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "pooled_prompt_embeds",
+ "negative_pooled_prompt_embeds",
+ ]
+ df = pd.DataFrame(
+ data,
+ columns=["image_hash"] + embedding_cols,
+ )
+
+ # Convert embedding lists to arrays (for proper storage in parquet)
+ for col in embedding_cols:
+ df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())
+
+ # Save the dataframe to a parquet file
+ df.to_parquet(args.output_path)
+ print(f"Data successfully serialized to {args.output_path}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--prompt", type=str, default=PROMPT, help="The instance prompt.")
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=MAX_SEQ_LENGTH,
+ help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
+ )
+ parser.add_argument(
+ "--local_data_dir", type=str, default=LOCAL_DATA_DIR, help="Path to the directory containing instance images."
+ )
+ parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")
+ args = parser.parse_args()
+
+ run(args)
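# A hedged sketch (not part of the new script) of reading the serialized embeddings back,
# mirroring what the miniature training script is expected to do with the parquet file.
import numpy as np
import pandas as pd
import torch

df = pd.read_parquet("sample_embeddings.parquet")
row = df.iloc[0]
# The columns were flattened to plain lists above, so the consumer has to restore the
# original (1, seq_len, dim) shapes, which are known from the SD3 text encoders.
prompt_embeds = torch.tensor(np.array(row["prompt_embeds"], dtype=np.float32))
pooled_prompt_embeds = torch.tensor(np.array(row["pooled_prompt_embeds"], dtype=np.float32))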
diff --git a/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb b/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb
new file mode 100644
index 000000000000..25fcb36a47d5
--- /dev/null
+++ b/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb
@@ -0,0 +1,2428 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "a6xLZDgOajbd"
+ },
+ "source": [
+ "# Running Stable Diffusion 3 (SD3) DreamBooth LoRA training under 16GB GPU VRAM"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0jPZpMTwafua"
+ },
+ "source": [
+ "## Install Dependencies"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "lIYdn1woOS1n",
+ "outputId": "6d4a6332-d1f5-46e2-ad2b-c9e51b9f279a"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q -U git+https://github.com/huggingface/diffusers\n",
+ "!pip install -q -U \\\n",
+ " transformers \\\n",
+ " accelerate \\\n",
+ " wandb \\\n",
+ " bitsandbytes \\\n",
+ " peft"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5qUNciw6aov2"
+ },
+ "source": [
+ "As SD3 is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youโve accepted the gate. Use the command below to log in:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Bpk5FleeK1NR",
+ "outputId": "54d8e774-514e-46fe-b9a7-0185e0bcf211"
+ },
+ "outputs": [],
+ "source": [
+ "!huggingface-cli login"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tcF7gl4FasJV"
+ },
+ "source": [
+ "## Clone `diffusers`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "QgSOJYglJKiM",
+ "outputId": "be51f30f-8848-4a79-ae91-c4fb89c244ba"
+ },
+ "outputs": [],
+ "source": [
+ "!git clone https://github.com/huggingface/diffusers\n",
+ "%cd diffusers/examples/research_projects/sd3_lora_colab"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "X9dBawr6ayRY"
+ },
+ "source": [
+ "## Download instance data images"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 351,
+ "referenced_widgets": [
+ "8720a1f0a3b043dba02b6aab0afb861a",
+ "0e70a30146ef4b30b014179bd4bfd131",
+ "c39072b8cfff4a11ba283a9ae3155e52",
+ "1e834badd9c74b95bda30456a585fc06",
+ "d7c7c83b341b4471ad8a0ca1fe76d9ff",
+ "5ab639bd765f4824818a53ab84f690a8",
+ "cd94205b05d54e4c96c9c475e13abe83",
+ "be260274fdb04798af6fce6169646ff2",
+ "b9912757b9f9477186c171ecb2551d3a",
+ "a1f88f8e27894cdfab54cad04871f74e",
+ "19026c269dce47d585506f734fa2981a",
+ "50237341e55e4da0ba5cdbf652f30115",
+ "1d006f25b17e4cd8aaa5f66d58940dc7",
+ "4c673aa247ff4d65b82c4c64ca2e72da",
+ "92698388b667476ea1daf5cacb2fdb07",
+ "f6b3aa0f980e450289ee15cea7cb3ed7",
+ "0690a95eb8c3403e90d5b023aaadb22c",
+ "4a20ceca22724ab082d013c20c758d31",
+ "c80f3825981646a8a3e178192e338962",
+ "5673d0ca1f1247dd924874355eadecd4",
+ "7774ac850ab2451ea380bf80f3be5a86",
+ "22e57b8c83fa48489d6a327f1bbb756b",
+ "dd2debcf1c774181bef97efab0f3d6e1",
+ "633d7df9a17e4bf6951249aca83a9e96",
+ "6469e9991d7b41a0b83a7b443b9eebe5",
+ "0b9c72fa39c241ba9dd22dd67c2436fe",
+ "99e707cfe1454757aad4014230f6dae8",
+ "5a4ec2d031fa438eb4d0492329b28f00",
+ "6c0d4d6d84704f88b46a9b5cf94e3836",
+ "e1fb8dec23c04d6f8d1217242f8a495c",
+ "4b35f9d8d6444d0397a8bafcf3d73e8f",
+ "0f3279a4e6a247a7b69ff73bc06acfe0",
+ "b5ac4ab9256e4d5092ba6e449bc3cdd3",
+ "2840e90c518d4666b3f5a935c90569a7",
+ "adb012e95d7d442a820680e61e615e3c",
+ "be4fd10d940d49cf8e916904da8192ab",
+ "fd93adba791f46c1b0a25ff692426149",
+ "cdee3b61ca6a487c8ec8e7e884eb8b07",
+ "190a7fbe2b554104a6d5b2caa3b0a08e",
+ "975922b877e143edb09cdb888cb7cae8",
+ "d7365b62df59406dbd38677299cce1c8",
+ "67f0f5f1179140b4bdaa74c5583e3958",
+ "e560f25c3e334cf2a4c748981ac38da6",
+ "65173381a80b40748b7b2800fdb89151",
+ "7a4c5c0acd2d400e91da611e91ff5306",
+ "2a02c69a19a741b4a032dedbc21ad088",
+ "c6211ddb71e64f9e92c70158da2f7ef1",
+ "c219a1e791894469aa1452045b0f74b5",
+ "8e881fd3a17e4a5d95e71f6411ed8167",
+ "5350f001bf774b5fb7f3e362f912bec3",
+ "a893c93bcbc444a4931d1ddc6e342786",
+ "03047a13f06744fcac17c77cb03bca62",
+ "4d77a9c44d1c47b18022f8a51b29e20d",
+ "0b5bb94394fc447282d2c44780303f15",
+ "01bfa49325a8403b808ad1662465996e",
+ "3c0f67144f974aea85c7088f482e830d",
+ "0996e9f698dc4d6ab3c4319c96186619",
+ "9c663933ece544b193531725f4dc873d",
+ "b48dc06ca1654fe39bf8a77352a921f2",
+ "641f5a361a584cc0b71fd71d5f786958",
+ "66a952451b4c43cab54312fe886df5e6",
+ "42757924240c4abeb39add0c26687ab3",
+ "7f26ae5417cf4c80921cce830e60f72b",
+ "b77093ec9ffd40e0b2c1d9d1bdc063f5",
+ "7d8e3510c1e34524993849b8fce52758",
+ "b15011804460483ab84904a52da754b7"
+ ]
+ },
+ "id": "La1rBYWFNjEP",
+ "outputId": "e8567843-193e-4653-86b8-be26390700df"
+ },
+ "outputs": [],
+ "source": [
+ "from huggingface_hub import snapshot_download\n",
+ "\n",
+ "local_dir = \"./dog\"\n",
+ "snapshot_download(\n",
+ " \"diffusers/dog-example\",\n",
+ " local_dir=local_dir, repo_type=\"dataset\",\n",
+ " ignore_patterns=\".gitattributes\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "hbsIzdjbOzgi"
+ },
+ "outputs": [],
+ "source": [
+ "!rm -rf dog/.huggingface"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "88sOTn2ga07q"
+ },
+ "source": [
+ "## Compute embeddings\n",
+ "\n",
+ "Here we are using the default instance prompt \"a photo of sks dog\". But you can configure this. Refer to the `compute_embeddings.py` script for details on other supported arguments."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ha6hPLpHLM8c",
+ "outputId": "82843eb0-473e-4d6b-d11d-1f79bcd1d11a"
+ },
+ "outputs": [],
+ "source": [
+ "!python compute_embeddings.py"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "10iMo-RUa_yv"
+ },
+ "source": [
+ "## Clear memory"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-YltRmPgMuNa"
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import gc\n",
+ "\n",
+ "\n",
+ "def flush():\n",
+ " torch.cuda.empty_cache()\n",
+ " gc.collect()\n",
+ "\n",
+ "flush()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UO5oEtOJbBS9"
+ },
+ "source": [
+ "## Train!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "HuJ6hdm2M4Aw",
+ "outputId": "0b2d8ca3-c65f-4bb4-af9a-809b77116510"
+ },
+ "outputs": [],
+ "source": [
+ "!accelerate launch train_dreambooth_lora_sd3_miniature.py \\\n",
+ " --pretrained_model_name_or_path=\"stabilityai/stable-diffusion-3-medium-diffusers\" \\\n",
+ " --instance_data_dir=\"dog\" \\\n",
+ " --data_df_path=\"sample_embeddings.parquet\" \\\n",
+ " --output_dir=\"trained-sd3-lora-miniature\" \\\n",
+ " --mixed_precision=\"fp16\" \\\n",
+ " --instance_prompt=\"a photo of sks dog\" \\\n",
+ " --resolution=1024 \\\n",
+ " --train_batch_size=1 \\\n",
+ " --gradient_accumulation_steps=4 --gradient_checkpointing \\\n",
+ " --use_8bit_adam \\\n",
+ " --learning_rate=1e-4 \\\n",
+ " --report_to=\"wandb\" \\\n",
+ " --lr_scheduler=\"constant\" \\\n",
+ " --lr_warmup_steps=0 \\\n",
+ " --max_train_steps=500 \\\n",
+ " --seed=\"0\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "itS-dsJ0gjy3"
+ },
+ "source": [
+ "Training will take about an hour to complete depending on the length of your dataset."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BpOuL7S1bI6j"
+ },
+ "source": [
+ "## Inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "clfMv4jKfQzb"
+ },
+ "outputs": [],
+ "source": [
+ "flush()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "np03SXHkbKpG"
+ },
+ "outputs": [],
+ "source": [
+ "from diffusers import DiffusionPipeline\n",
+ "import torch\n",
+ "\n",
+ "pipeline = DiffusionPipeline.from_pretrained(\n",
+ " \"stabilityai/stable-diffusion-3-medium-diffusers\",\n",
+ " torch_dtype=torch.float16\n",
+ ")\n",
+ "lora_output_path = \"trained-sd3-lora-miniature\"\n",
+ "pipeline.load_lora_weights(\"trained-sd3-lora-miniature\")\n",
+ "\n",
+ "pipeline.enable_sequential_cpu_offload()\n",
+ "\n",
+ "image = pipeline(\"a photo of sks dog in a bucket\").images[0]\n",
+ "image.save(\"bucket_dog.png\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HDfrY2opjGjD"
+ },
+ "source": [
+ "Note that inference will be very slow in this case because we're loading and unloading individual components of the models and that introduces significant data movement overhead. Refer to [this resource](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more memory optimization related techniques."
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "01bfa49325a8403b808ad1662465996e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "03047a13f06744fcac17c77cb03bca62": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0690a95eb8c3403e90d5b023aaadb22c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0996e9f698dc4d6ab3c4319c96186619": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_66a952451b4c43cab54312fe886df5e6",
+ "placeholder": "โ",
+ "style": "IPY_MODEL_42757924240c4abeb39add0c26687ab3",
+ "value": "alvan-nee-bQaAJCbNq3g-unsplash.jpeg:โ100%"
+ }
+ },
+ "0b5bb94394fc447282d2c44780303f15": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0b9c72fa39c241ba9dd22dd67c2436fe": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_0f3279a4e6a247a7b69ff73bc06acfe0",
+ "placeholder": "โ",
+ "style": "IPY_MODEL_b5ac4ab9256e4d5092ba6e449bc3cdd3",
+ "value": "โ1.16M/1.16Mโ[00:00<00:00,โ10.3MB/s]"
+ }
+ },
+ "0e70a30146ef4b30b014179bd4bfd131": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_5ab639bd765f4824818a53ab84f690a8",
+ "placeholder": "โ",
+ "style": "IPY_MODEL_cd94205b05d54e4c96c9c475e13abe83",
+ "value": "Fetchingโ5โfiles:โ100%"
+ }
+ },
+ "0f3279a4e6a247a7b69ff73bc06acfe0": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "19026c269dce47d585506f734fa2981a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "190a7fbe2b554104a6d5b2caa3b0a08e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1d006f25b17e4cd8aaa5f66d58940dc7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_0690a95eb8c3403e90d5b023aaadb22c",
+ "placeholder": "โ",
+ "style": "IPY_MODEL_4a20ceca22724ab082d013c20c758d31",
+ "value": "alvan-nee-9M0tSjb-cpA-unsplash.jpeg:โ100%"
+ }
+ },
+ "1e834badd9c74b95bda30456a585fc06": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_a1f88f8e27894cdfab54cad04871f74e",
+ "placeholder": "โ",
+ "style": "IPY_MODEL_19026c269dce47d585506f734fa2981a",
+ "value": "โ5/5โ[00:01<00:00,โโ3.44it/s]"
+ }
+ },
+ "22e57b8c83fa48489d6a327f1bbb756b": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "2840e90c518d4666b3f5a935c90569a7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_adb012e95d7d442a820680e61e615e3c",
+ "IPY_MODEL_be4fd10d940d49cf8e916904da8192ab",
+ "IPY_MODEL_fd93adba791f46c1b0a25ff692426149"
+ ],
+ "layout": "IPY_MODEL_cdee3b61ca6a487c8ec8e7e884eb8b07"
+ }
+ },
+ "2a02c69a19a741b4a032dedbc21ad088": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_5350f001bf774b5fb7f3e362f912bec3",
+ "placeholder": "โ",
+ "style": "IPY_MODEL_a893c93bcbc444a4931d1ddc6e342786",
+ "value": "alvan-nee-eoqnr8ikwFE-unsplash.jpeg:โ100%"
+ }
+ },
+ "3c0f67144f974aea85c7088f482e830d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_0996e9f698dc4d6ab3c4319c96186619",
+ "IPY_MODEL_9c663933ece544b193531725f4dc873d",
+ "IPY_MODEL_b48dc06ca1654fe39bf8a77352a921f2"
+ ],
+ "layout": "IPY_MODEL_641f5a361a584cc0b71fd71d5f786958"
+ }
+ },
+ "42757924240c4abeb39add0c26687ab3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "4a20ceca22724ab082d013c20c758d31": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "4b35f9d8d6444d0397a8bafcf3d73e8f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "4c673aa247ff4d65b82c4c64ca2e72da": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_c80f3825981646a8a3e178192e338962",
+ "max": 677407,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_5673d0ca1f1247dd924874355eadecd4",
+ "value": 677407
+ }
+ },
+ "4d77a9c44d1c47b18022f8a51b29e20d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "50237341e55e4da0ba5cdbf652f30115": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_1d006f25b17e4cd8aaa5f66d58940dc7",
+ "IPY_MODEL_4c673aa247ff4d65b82c4c64ca2e72da",
+ "IPY_MODEL_92698388b667476ea1daf5cacb2fdb07"
+ ],
+ "layout": "IPY_MODEL_f6b3aa0f980e450289ee15cea7cb3ed7"
+ }
+ },
+ "5350f001bf774b5fb7f3e362f912bec3": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "5673d0ca1f1247dd924874355eadecd4": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "5a4ec2d031fa438eb4d0492329b28f00": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "5ab639bd765f4824818a53ab84f690a8": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "633d7df9a17e4bf6951249aca83a9e96": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_5a4ec2d031fa438eb4d0492329b28f00",
+ "placeholder": "โ",
+ "style": "IPY_MODEL_6c0d4d6d84704f88b46a9b5cf94e3836",
+ "value": "alvan-nee-Id1DBHv4fbg-unsplash.jpeg:โ100%"
+ }
+ },
+ "641f5a361a584cc0b71fd71d5f786958": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "6469e9991d7b41a0b83a7b443b9eebe5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e1fb8dec23c04d6f8d1217242f8a495c",
+ "max": 1163467,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_4b35f9d8d6444d0397a8bafcf3d73e8f",
+ "value": 1163467
+ }
+ },
+ "65173381a80b40748b7b2800fdb89151": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "66a952451b4c43cab54312fe886df5e6": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "67f0f5f1179140b4bdaa74c5583e3958": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "6c0d4d6d84704f88b46a9b5cf94e3836": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "7774ac850ab2451ea380bf80f3be5a86": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "7a4c5c0acd2d400e91da611e91ff5306": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_2a02c69a19a741b4a032dedbc21ad088",
+ "IPY_MODEL_c6211ddb71e64f9e92c70158da2f7ef1",
+ "IPY_MODEL_c219a1e791894469aa1452045b0f74b5"
+ ],
+ "layout": "IPY_MODEL_8e881fd3a17e4a5d95e71f6411ed8167"
+ }
+ },
+ "7d8e3510c1e34524993849b8fce52758": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "7f26ae5417cf4c80921cce830e60f72b": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "8720a1f0a3b043dba02b6aab0afb861a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_0e70a30146ef4b30b014179bd4bfd131",
+ "IPY_MODEL_c39072b8cfff4a11ba283a9ae3155e52",
+ "IPY_MODEL_1e834badd9c74b95bda30456a585fc06"
+ ],
+ "layout": "IPY_MODEL_d7c7c83b341b4471ad8a0ca1fe76d9ff"
+ }
+ },
+ "8e881fd3a17e4a5d95e71f6411ed8167": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "92698388b667476ea1daf5cacb2fdb07": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_7774ac850ab2451ea380bf80f3be5a86",
+ "placeholder": "โ",
+ "style": "IPY_MODEL_22e57b8c83fa48489d6a327f1bbb756b",
+ "value": "โ677k/677kโ[00:00<00:00,โ9.50MB/s]"
+ }
+ },
+ "975922b877e143edb09cdb888cb7cae8": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "99e707cfe1454757aad4014230f6dae8": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9c663933ece544b193531725f4dc873d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_7f26ae5417cf4c80921cce830e60f72b",
+ "max": 1396297,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_b77093ec9ffd40e0b2c1d9d1bdc063f5",
+ "value": 1396297
+ }
+ },
+ "a1f88f8e27894cdfab54cad04871f74e": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "a893c93bcbc444a4931d1ddc6e342786": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "adb012e95d7d442a820680e61e615e3c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_190a7fbe2b554104a6d5b2caa3b0a08e",
+ "placeholder": "โ",
+ "style": "IPY_MODEL_975922b877e143edb09cdb888cb7cae8",
+ "value": "alvan-nee-brFsZ7qszSY-unsplash.jpeg:โ100%"
+ }
+ },
+ "b15011804460483ab84904a52da754b7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "b48dc06ca1654fe39bf8a77352a921f2": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_7d8e3510c1e34524993849b8fce52758",
+ "placeholder": "โ",
+ "style": "IPY_MODEL_b15011804460483ab84904a52da754b7",
+ "value": "โ1.40M/1.40Mโ[00:00<00:00,โ10.5MB/s]"
+ }
+ },
+ "b5ac4ab9256e4d5092ba6e449bc3cdd3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "b77093ec9ffd40e0b2c1d9d1bdc063f5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "b9912757b9f9477186c171ecb2551d3a": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "be260274fdb04798af6fce6169646ff2": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "be4fd10d940d49cf8e916904da8192ab": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_d7365b62df59406dbd38677299cce1c8",
+ "max": 1186464,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_67f0f5f1179140b4bdaa74c5583e3958",
+ "value": 1186464
+ }
+ },
+ "c219a1e791894469aa1452045b0f74b5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_0b5bb94394fc447282d2c44780303f15",
+ "placeholder": "โ",
+ "style": "IPY_MODEL_01bfa49325a8403b808ad1662465996e",
+ "value": "โ1.17M/1.17Mโ[00:00<00:00,โ15.7MB/s]"
+ }
+ },
+ "c39072b8cfff4a11ba283a9ae3155e52": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_be260274fdb04798af6fce6169646ff2",
+ "max": 5,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_b9912757b9f9477186c171ecb2551d3a",
+ "value": 5
+ }
+ },
+ "c6211ddb71e64f9e92c70158da2f7ef1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_03047a13f06744fcac17c77cb03bca62",
+ "max": 1167042,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_4d77a9c44d1c47b18022f8a51b29e20d",
+ "value": 1167042
+ }
+ },
+ "c80f3825981646a8a3e178192e338962": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "cd94205b05d54e4c96c9c475e13abe83": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "cdee3b61ca6a487c8ec8e7e884eb8b07": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d7365b62df59406dbd38677299cce1c8": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "d7c7c83b341b4471ad8a0ca1fe76d9ff": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "dd2debcf1c774181bef97efab0f3d6e1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_633d7df9a17e4bf6951249aca83a9e96",
+ "IPY_MODEL_6469e9991d7b41a0b83a7b443b9eebe5",
+ "IPY_MODEL_0b9c72fa39c241ba9dd22dd67c2436fe"
+ ],
+ "layout": "IPY_MODEL_99e707cfe1454757aad4014230f6dae8"
+ }
+ },
+ "e1fb8dec23c04d6f8d1217242f8a495c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e560f25c3e334cf2a4c748981ac38da6": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "f6b3aa0f980e450289ee15cea7cb3ed7": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "fd93adba791f46c1b0a25ff692426149": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_e560f25c3e334cf2a4c748981ac38da6",
+            "placeholder": " ",
+ "style": "IPY_MODEL_65173381a80b40748b7b2800fdb89151",
+            "value": " 1.19M/1.19M [00:00<00:00, 12.7MB/s]"
+ }
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
new file mode 100644
index 000000000000..163ff8f08931
--- /dev/null
+++ b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
@@ -0,0 +1,1147 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import gc
+import hashlib
+import logging
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ SD3Transformer2DModel,
+ StableDiffusion3Pipeline,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ cast_training_params,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+)
+from diffusers.utils import (
+ check_min_version,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+check_min_version("0.30.0.dev0")
+
+logger = get_logger(__name__)
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ train_text_encoder=False,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# SD3 DreamBooth LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/).
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
+
+## Trigger words
+
+You should use {instance_prompt} to trigger the image generation.
+
+## Download model
+
+[Download]({repo_id}/tree/main) them in the Files & versions tab.
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="openrail++",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "sd3",
+ "sd3-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline.enable_model_cpu_offload()
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+ # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
+ autocast_ctx = nullcontext()
+
+ with autocast_ctx:
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return images
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+ parser.add_argument(
+ "--data_df_path",
+ type=str,
+ default=None,
+ help=("Path to the parquet file serialized with compute_embeddings.py."),
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=77,
+        help="Maximum sequence length to use with the T5 text encoder",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="sd3-dreambooth-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="logit_normal",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params")
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer.",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.instance_data_dir is None:
+ raise ValueError("Specify `instance_data_dir`.")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ data_df_path,
+ instance_data_root,
+ instance_prompt,
+ size=1024,
+ center_crop=False,
+ ):
+ # Logistics
+ self.size = size
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+            raise ValueError("Instance images root doesn't exist.")
+
+ # Load images.
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ image_hashes = [self.generate_image_hash(path) for path in list(Path(instance_data_root).iterdir())]
+ self.instance_images = instance_images
+ self.image_hashes = image_hashes
+
+ # Image transformations
+ self.pixel_values = self.apply_image_transformations(
+ instance_images=instance_images, size=size, center_crop=center_crop
+ )
+
+ # Map hashes to embeddings.
+ self.data_dict = self.map_image_hash_embedding(data_df_path=data_df_path)
+
+ self.num_instance_images = len(instance_images)
+ self._length = self.num_instance_images
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image = self.pixel_values[index % self.num_instance_images]
+ image_hash = self.image_hashes[index % self.num_instance_images]
+ prompt_embeds, pooled_prompt_embeds = self.data_dict[image_hash]
+ example["instance_images"] = instance_image
+ example["prompt_embeds"] = prompt_embeds
+ example["pooled_prompt_embeds"] = pooled_prompt_embeds
+ return example
+
+ def apply_image_transformations(self, instance_images, size, center_crop):
+ pixel_values = []
+
+ train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
+ train_transforms = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+ for image in instance_images:
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = train_resize(image)
+ if args.random_flip and random.random() < 0.5:
+ # flip
+ image = train_flip(image)
+ if args.center_crop:
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ image = train_crop(image)
+ else:
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ image = crop(image, y1, x1, h, w)
+ image = train_transforms(image)
+ pixel_values.append(image)
+
+ return pixel_values
+
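+    # NOTE: compute_embeddings.py serializes the prompt embeddings flattened in the parquet file;
+    # reshape them back into (sequence_length, hidden_dim) and (pooled_dim,) tensors here.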
+ def convert_to_torch_tensor(self, embeddings: list):
+ prompt_embeds = embeddings[0]
+ pooled_prompt_embeds = embeddings[1]
+ prompt_embeds = np.array(prompt_embeds).reshape(154, 4096)
+ pooled_prompt_embeds = np.array(pooled_prompt_embeds).reshape(2048)
+ return torch.from_numpy(prompt_embeds), torch.from_numpy(pooled_prompt_embeds)
+
+ def map_image_hash_embedding(self, data_df_path):
+ hashes_df = pd.read_parquet(data_df_path)
+ data_dict = {}
+ for i, row in hashes_df.iterrows():
+ embeddings = [row["prompt_embeds"], row["pooled_prompt_embeds"]]
+ prompt_embeds, pooled_prompt_embeds = self.convert_to_torch_tensor(embeddings=embeddings)
+ data_dict.update({row["image_hash"]: (prompt_embeds, pooled_prompt_embeds)})
+ return data_dict
+
+ def generate_image_hash(self, image_path):
+ with open(image_path, "rb") as f:
+ img_data = f.read()
+ return hashlib.sha256(img_data).hexdigest()
+
+
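+# Stack the per-example pixel values and precomputed prompt embeddings into batched tensors.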
+def collate_fn(examples):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompt_embeds = [example["prompt_embeds"] for example in examples]
+ pooled_prompt_embeds = [example["pooled_prompt_embeds"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+ prompt_embeds = torch.stack(prompt_embeds)
+ pooled_prompt_embeds = torch.stack(pooled_prompt_embeds)
+
+ batch = {
+ "pixel_values": pixel_values,
+ "prompt_embeds": prompt_embeds,
+ "pooled_prompt_embeds": pooled_prompt_embeds,
+ }
+ return batch
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ transformer = SD3Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+ )
+
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ vae.to(accelerator.device, dtype=torch.float32)
+ transformer.to(accelerator.device, dtype=weight_dtype)
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+
+ # now we will add new LoRA weights to the attention layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+ transformer.add_adapter(transformer_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+ for model in models:
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ StableDiffusion3Pipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+            f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ # Optimization parameters
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if not args.optimizer.lower() == "adamw":
+ logger.warning(
+            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW]. "
+            "Defaulting to adamW."
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ data_df_path=args.data_df_path,
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ size=args.resolution,
+ center_crop=args.center_crop,
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ shuffle=True,
+        collate_fn=collate_fn,
+ num_workers=args.dataloader_num_workers,
+ )
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-sd3-lora-miniature"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+            # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
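+    # Look up the continuous sigma for each sampled discrete timestep from the flow-matching
+    # scheduler and broadcast it so it can be applied to the latent tensors.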
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ with accelerator.accumulate(models_to_accumulate):
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+
+ # Convert images to latent space
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ model_input = model_input.to(dtype=weight_dtype)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
+
+ # Predict the noise residual
+ prompt_embeds, pooled_prompt_embeds = batch["prompt_embeds"], batch["pooled_prompt_embeds"]
+ prompt_embeds = prompt_embeds.to(device=accelerator.device, dtype=weight_dtype)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(device=accelerator.device, dtype=weight_dtype)
+ model_pred = transformer(
+ hidden_states=noisy_model_input,
+ timestep=timesteps,
+ encoder_hidden_states=prompt_embeds,
+ pooled_projections=pooled_prompt_embeds,
+ return_dict=False,
+ )[0]
+
+ # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+ # Preconditioning of the model outputs.
+ model_pred = model_pred * (-sigmas) + noisy_model_input
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss
+ target = model_input
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = transformer_lora_parameters
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ transformer=accelerator.unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline_args = {"prompt": args.validation_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ )
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ transformer = unwrap_model(transformer)
+ transformer = transformer.to(torch.float32)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+
+ StableDiffusion3Pipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ )
+
+ # Final inference
+ # Load previous pipeline
+ pipeline = StableDiffusion3Pipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {"prompt": args.validation_prompt}
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ is_final_validation=True,
+ )
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/research_projects/vae/README.md b/examples/research_projects/vae/README.md
new file mode 100644
index 000000000000..2e24c955b7ae
--- /dev/null
+++ b/examples/research_projects/vae/README.md
@@ -0,0 +1,11 @@
+# VAE
+
+`vae_roundtrip.py` demonstrates the use of a VAE by round-tripping an image through the encoder and decoder. The original and reconstructed images are displayed side by side.
+
+```
+cd examples/research_projects/vae
+python vae_roundtrip.py \
+ --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
+ --subfolder="vae" \
+ --input_image="/path/to/your/input.png"
+```
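+
+You can also round-trip through a tiny autoencoder (`AutoencoderTiny`) instead of the full `AutoencoderKL` by passing `--use_tiny_nn`. For example, mirroring the usage shown in the script's comments with the `madebyollin/taesd` checkpoint:
+
+```
+python vae_roundtrip.py \
+    --pretrained_model_name_or_path="madebyollin/taesd" \
+    --use_tiny_nn \
+    --input_image="/path/to/your/input.png"
+```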
diff --git a/examples/research_projects/vae/vae_roundtrip.py b/examples/research_projects/vae/vae_roundtrip.py
new file mode 100644
index 000000000000..65c2b43a9bde
--- /dev/null
+++ b/examples/research_projects/vae/vae_roundtrip.py
@@ -0,0 +1,282 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import typing
+from typing import Optional, Union
+
+import torch
+from PIL import Image
+from torchvision import transforms # type: ignore
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.models.autoencoders.autoencoder_kl import (
+ AutoencoderKL,
+ AutoencoderKLOutput,
+)
+from diffusers.models.autoencoders.autoencoder_tiny import (
+ AutoencoderTiny,
+ AutoencoderTinyOutput,
+)
+from diffusers.models.autoencoders.vae import DecoderOutput
+
+
+SupportedAutoencoder = Union[AutoencoderKL, AutoencoderTiny]
+
+
+def load_vae_model(
+ *,
+ device: torch.device,
+ model_name_or_path: str,
+ revision: Optional[str],
+ variant: Optional[str],
+ # NOTE: use subfolder="vae" if the pointed model is for stable diffusion as a whole instead of just the VAE
+ subfolder: Optional[str],
+ use_tiny_nn: bool,
+) -> SupportedAutoencoder:
+ if use_tiny_nn:
+ # NOTE: These scaling factors don't have to be the same as each other.
+ down_scale = 2
+ up_scale = 2
+ vae = AutoencoderTiny.from_pretrained( # type: ignore
+ model_name_or_path,
+ subfolder=subfolder,
+ revision=revision,
+ variant=variant,
+ downscaling_scaling_factor=down_scale,
+ upsampling_scaling_factor=up_scale,
+ )
+ assert isinstance(vae, AutoencoderTiny)
+ else:
+ vae = AutoencoderKL.from_pretrained( # type: ignore
+ model_name_or_path,
+ subfolder=subfolder,
+ revision=revision,
+ variant=variant,
+ )
+ assert isinstance(vae, AutoencoderKL)
+ vae = vae.to(device)
+ vae.eval() # Set the model to inference mode
+ return vae
+
+
+def pil_to_nhwc(
+ *,
+ device: torch.device,
+ image: Image.Image,
+) -> torch.Tensor:
+ assert image.mode == "RGB"
+ transform = transforms.ToTensor()
+ nhwc = transform(image).unsqueeze(0).to(device) # type: ignore
+ assert isinstance(nhwc, torch.Tensor)
+ return nhwc
+
+
+def nhwc_to_pil(
+ *,
+ nhwc: torch.Tensor,
+) -> Image.Image:
+ assert nhwc.shape[0] == 1
+ hwc = nhwc.squeeze(0).cpu()
+ return transforms.ToPILImage()(hwc) # type: ignore
+
+
+def concatenate_images(
+ *,
+ left: Image.Image,
+ right: Image.Image,
+ vertical: bool = False,
+) -> Image.Image:
+ width1, height1 = left.size
+ width2, height2 = right.size
+ if vertical:
+ total_height = height1 + height2
+ max_width = max(width1, width2)
+ new_image = Image.new("RGB", (max_width, total_height))
+ new_image.paste(left, (0, 0))
+ new_image.paste(right, (0, height1))
+ else:
+ total_width = width1 + width2
+ max_height = max(height1, height2)
+ new_image = Image.new("RGB", (total_width, max_height))
+ new_image.paste(left, (0, 0))
+ new_image.paste(right, (width1, 0))
+ return new_image
+
+
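+# Normalize pixel values to [-1, 1] and encode them with the VAE. AutoencoderKL returns a latent
+# distribution to sample from, while AutoencoderTiny returns the latents directly.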
+def to_latent(
+ *,
+ rgb_nchw: torch.Tensor,
+ vae: SupportedAutoencoder,
+) -> torch.Tensor:
+ rgb_nchw = VaeImageProcessor.normalize(rgb_nchw) # type: ignore
+ encoding_nchw = vae.encode(typing.cast(torch.FloatTensor, rgb_nchw))
+ if isinstance(encoding_nchw, AutoencoderKLOutput):
+ latent = encoding_nchw.latent_dist.sample() # type: ignore
+ assert isinstance(latent, torch.Tensor)
+ elif isinstance(encoding_nchw, AutoencoderTinyOutput):
+ latent = encoding_nchw.latents
+ do_internal_vae_scaling = False # Is this needed?
+ if do_internal_vae_scaling:
+ latent = vae.scale_latents(latent).mul(255).round().byte() # type: ignore
+ latent = vae.unscale_latents(latent / 255.0) # type: ignore
+ assert isinstance(latent, torch.Tensor)
+ else:
+ assert False, f"Unknown encoding type: {type(encoding_nchw)}"
+ return latent
+
+
+def from_latent(
+ *,
+ latent_nchw: torch.Tensor,
+ vae: SupportedAutoencoder,
+) -> torch.Tensor:
+ decoding_nchw = vae.decode(latent_nchw) # type: ignore
+ assert isinstance(decoding_nchw, DecoderOutput)
+ rgb_nchw = VaeImageProcessor.denormalize(decoding_nchw.sample) # type: ignore
+ assert isinstance(rgb_nchw, torch.Tensor)
+ return rgb_nchw
+
+
+def main_kwargs(
+ *,
+ device: torch.device,
+ input_image_path: str,
+ pretrained_model_name_or_path: str,
+ revision: Optional[str],
+ variant: Optional[str],
+ subfolder: Optional[str],
+ use_tiny_nn: bool,
+) -> None:
+ vae = load_vae_model(
+ device=device,
+ model_name_or_path=pretrained_model_name_or_path,
+ revision=revision,
+ variant=variant,
+ subfolder=subfolder,
+ use_tiny_nn=use_tiny_nn,
+ )
+ original_pil = Image.open(input_image_path).convert("RGB")
+ original_image = pil_to_nhwc(
+ device=device,
+ image=original_pil,
+ )
+ print(f"Original image shape: {original_image.shape}")
+ reconstructed_image: Optional[torch.Tensor] = None
+
+ with torch.no_grad():
+ latent_image = to_latent(rgb_nchw=original_image, vae=vae)
+ print(f"Latent shape: {latent_image.shape}")
+ reconstructed_image = from_latent(latent_nchw=latent_image, vae=vae)
+ reconstructed_pil = nhwc_to_pil(nhwc=reconstructed_image)
+ combined_image = concatenate_images(
+ left=original_pil,
+ right=reconstructed_pil,
+ vertical=False,
+ )
+ combined_image.show("Original | Reconstruction")
+ print(f"Reconstructed image shape: {reconstructed_image.shape}")
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description="Inference with VAE")
+ parser.add_argument(
+ "--input_image",
+ type=str,
+ required=True,
+ help="Path to the input image for inference.",
+ )
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ required=True,
+ help="Path to pretrained VAE model.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ help="Model version.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Model file variant, e.g., 'fp16'.",
+ )
+ parser.add_argument(
+ "--subfolder",
+ type=str,
+ default=None,
+        help="Subfolder within the model repository.",
+ )
+ parser.add_argument(
+ "--use_cuda",
+ action="store_true",
+ help="Use CUDA if available.",
+ )
+ parser.add_argument(
+ "--use_tiny_nn",
+ action="store_true",
+ help="Use tiny neural network.",
+ )
+ return parser.parse_args()
+
+
+# EXAMPLE USAGE:
+#
+# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" --subfolder "vae" --input_image "foo.png"
+#
+# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "madebyollin/taesd" --use_tiny_nn --input_image "foo.png"
+#
+def main_cli() -> None:
+ args = parse_args()
+
+ input_image_path = args.input_image
+ assert isinstance(input_image_path, str)
+
+ pretrained_model_name_or_path = args.pretrained_model_name_or_path
+ assert isinstance(pretrained_model_name_or_path, str)
+
+ revision = args.revision
+ assert isinstance(revision, (str, type(None)))
+
+ variant = args.variant
+ assert isinstance(variant, (str, type(None)))
+
+ subfolder = args.subfolder
+ assert isinstance(subfolder, (str, type(None)))
+
+ use_cuda = args.use_cuda
+ assert isinstance(use_cuda, bool)
+
+ use_tiny_nn = args.use_tiny_nn
+ assert isinstance(use_tiny_nn, bool)
+
+ device = torch.device("cuda" if use_cuda else "cpu")
+
+ main_kwargs(
+ device=device,
+ input_image_path=input_image_path,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ revision=revision,
+ variant=variant,
+ subfolder=subfolder,
+ use_tiny_nn=use_tiny_nn,
+ )
+
+
+if __name__ == "__main__":
+ main_cli()
diff --git a/scripts/convert_hunyuandit_controlnet_to_diffusers.py b/scripts/convert_hunyuandit_controlnet_to_diffusers.py
new file mode 100644
index 000000000000..1c8383690890
--- /dev/null
+++ b/scripts/convert_hunyuandit_controlnet_to_diffusers.py
@@ -0,0 +1,241 @@
+import argparse
+
+import torch
+
+from diffusers import HunyuanDiT2DControlNetModel
+
+
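+# Convert an original HunyuanDiT ControlNet checkpoint into the parameter layout expected by
+# diffusers' HunyuanDiT2DControlNetModel.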
+def main(args):
+ state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")
+
+ if args.load_key != "none":
+ try:
+ state_dict = state_dict[args.load_key]
+ except KeyError:
+ raise KeyError(
+                f"{args.load_key} not found in the checkpoint. "
+                f"Please load from the following keys: {state_dict.keys()}"
+ )
+ device = "cuda"
+
+ model_config = HunyuanDiT2DControlNetModel.load_config(
+ "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
+ )
+ model_config[
+ "use_style_cond_and_image_meta_size"
+ ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
+ print(model_config)
+
+ for key in state_dict:
+ print("local:", key)
+
+ model = HunyuanDiT2DControlNetModel.from_config(model_config).to(device)
+
+ for key in model.state_dict():
+ print("diffusers:", key)
+
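+    # Remap the attention, norm, and MLP weights of each of the 19 transformer blocks to the
+    # parameter names expected by the diffusers model.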
+ num_layers = 19
+ for i in range(num_layers):
+ # attn1
+        # Wqkv -> to_q, to_k, to_v
+ q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
+ q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
+ state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
+ state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
+ state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
+ state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
+ state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
+ state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
+ state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
+ state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
+
+ # q_norm, k_norm -> norm_q, norm_k
+ state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
+ state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
+ state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
+ state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]
+
+ state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
+ state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")
+
+ # out_proj -> to_out
+ state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
+ state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")
+
+ # attn2
+        # kv_proj -> to_k, to_v
+ k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
+ k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
+ state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
+ state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
+ state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
+ state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
+ state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")
+
+ # q_proj -> to_q
+ state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
+ state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")
+
+ # q_norm, k_norm -> norm_q, norm_k
+ state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
+ state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
+ state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
+ state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]
+
+ state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
+ state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")
+
+ # out_proj -> to_out
+ state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
+ state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")
+
+ # switch norm 2 and norm 3
+ norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
+ norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
+ state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
+ state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
+ state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
+ state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias
+
+ # norm1 -> norm1.norm
+ # default_modulation.1 -> norm1.linear
+ state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"]
+ state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"]
+ state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"]
+ state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"]
+ state_dict.pop(f"blocks.{i}.norm1.weight")
+ state_dict.pop(f"blocks.{i}.norm1.bias")
+ state_dict.pop(f"blocks.{i}.default_modulation.1.weight")
+ state_dict.pop(f"blocks.{i}.default_modulation.1.bias")
+
+ # mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2
+ state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"]
+ state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"]
+ state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"]
+ state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"]
+ state_dict.pop(f"blocks.{i}.mlp.fc1.weight")
+ state_dict.pop(f"blocks.{i}.mlp.fc1.bias")
+ state_dict.pop(f"blocks.{i}.mlp.fc2.weight")
+ state_dict.pop(f"blocks.{i}.mlp.fc2.bias")
+
+ # after_proj_list -> controlnet_blocks
+ state_dict[f"controlnet_blocks.{i}.weight"] = state_dict[f"after_proj_list.{i}.weight"]
+ state_dict[f"controlnet_blocks.{i}.bias"] = state_dict[f"after_proj_list.{i}.bias"]
+ state_dict.pop(f"after_proj_list.{i}.weight")
+ state_dict.pop(f"after_proj_list.{i}.bias")
+
+ # before_proj -> input_block
+ state_dict["input_block.weight"] = state_dict["before_proj.weight"]
+ state_dict["input_block.bias"] = state_dict["before_proj.bias"]
+ state_dict.pop("before_proj.weight")
+ state_dict.pop("before_proj.bias")
+
+ # pooler -> time_extra_emb
+ state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
+ state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
+ state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
+ state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
+ state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
+ state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
+ state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
+ state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
+ state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
+ state_dict.pop("pooler.k_proj.weight")
+ state_dict.pop("pooler.k_proj.bias")
+ state_dict.pop("pooler.q_proj.weight")
+ state_dict.pop("pooler.q_proj.bias")
+ state_dict.pop("pooler.v_proj.weight")
+ state_dict.pop("pooler.v_proj.bias")
+ state_dict.pop("pooler.c_proj.weight")
+ state_dict.pop("pooler.c_proj.bias")
+ state_dict.pop("pooler.positional_embedding")
+
+    # t_embedder -> time_extra_emb.timestep_embedder (`TimestepEmbedding`)
+ state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
+ state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
+ state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
+ state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]
+
+ state_dict.pop("t_embedder.mlp.0.bias")
+ state_dict.pop("t_embedder.mlp.0.weight")
+ state_dict.pop("t_embedder.mlp.2.bias")
+ state_dict.pop("t_embedder.mlp.2.weight")
+
+    # x_embedder -> pos_embed (`PatchEmbed`)
+ state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
+ state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
+ state_dict.pop("x_embedder.proj.weight")
+ state_dict.pop("x_embedder.proj.bias")
+
+ # mlp_t5 -> text_embedder
+ state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"]
+ state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"]
+ state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"]
+ state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"]
+ state_dict.pop("mlp_t5.0.bias")
+ state_dict.pop("mlp_t5.0.weight")
+ state_dict.pop("mlp_t5.2.bias")
+ state_dict.pop("mlp_t5.2.weight")
+
+    # extra_embedder -> time_extra_emb.extra_embedder
+ state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"]
+ state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"]
+ state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"]
+ state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"]
+ state_dict.pop("extra_embedder.0.bias")
+ state_dict.pop("extra_embedder.0.weight")
+ state_dict.pop("extra_embedder.2.bias")
+ state_dict.pop("extra_embedder.2.weight")
+
+ # style_embedder
+ if model_config["use_style_cond_and_image_meta_size"]:
+ print(state_dict["style_embedder.weight"])
+ print(state_dict["style_embedder.weight"].shape)
+ state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1]
+ state_dict.pop("style_embedder.weight")
+
+ model.load_state_dict(state_dict)
+
+ if args.save:
+ model.save_pretrained(args.output_checkpoint_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
+ )
+ parser.add_argument(
+ "--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
+ )
+ parser.add_argument(
+ "--output_checkpoint_path",
+ default=None,
+ type=str,
+ required=False,
+        help="Path to the output converted diffusers model.",
+ )
+ parser.add_argument(
+ "--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file"
+ )
+ parser.add_argument(
+ "--use_style_cond_and_image_meta_size",
+ type=bool,
+ default=False,
+ help="version <= v1.1: True; version >= v1.2: False",
+ )
+
+ args = parser.parse_args()
+ main(args)
diff --git a/scripts/convert_hunyuandit_to_diffusers.py b/scripts/convert_hunyuandit_to_diffusers.py
new file mode 100644
index 000000000000..da3af8333ee3
--- /dev/null
+++ b/scripts/convert_hunyuandit_to_diffusers.py
@@ -0,0 +1,279 @@
+import argparse
+
+import torch
+
+from diffusers import HunyuanDiT2DModel
+
+
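+# Convert an original HunyuanDiT transformer checkpoint (a `.pt` state dict) into the
+# diffusers `HunyuanDiT2DModel` format and run a quick sanity-check generation. A typical
+# invocation (paths below are illustrative) might look like:
+#   python convert_hunyuandit_to_diffusers.py \
+#       --pt_checkpoint_path ./hunyuandit.pt \
+#       --output_checkpoint_path ./hunyuandit-diffusers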
+def main(args):
+ state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")
+
+ if args.load_key != "none":
+ try:
+ state_dict = state_dict[args.load_key]
+ except KeyError:
+ raise KeyError(
+                f"{args.load_key} not found in the checkpoint. "
+                f"Please load from the following keys: {state_dict.keys()}"
+ )
+
+ device = "cuda"
+ model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
+ model_config[
+ "use_style_cond_and_image_meta_size"
+ ] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
+
+ # input_size -> sample_size, text_dim -> cross_attention_dim
+ for key in state_dict:
+ print("local:", key)
+
+ model = HunyuanDiT2DModel.from_config(model_config).to(device)
+
+ for key in model.state_dict():
+ print("diffusers:", key)
+
+ num_layers = 40
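+    # Remap all 40 transformer blocks from the original fused/abbreviated parameter names
+    # to the diffusers layout (same per-block scheme as the ControlNet conversion script).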
+ for i in range(num_layers):
+ # attn1
+        # Wqkv -> to_q, to_k, to_v
+ q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
+ q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
+ state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
+ state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
+ state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
+ state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
+ state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
+ state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
+ state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
+ state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
+
+ # q_norm, k_norm -> norm_q, norm_k
+ state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
+ state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
+ state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
+ state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]
+
+ state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
+ state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")
+
+ # out_proj -> to_out
+ state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
+ state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")
+
+ # attn2
+        # kv_proj -> to_k, to_v
+ k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
+ k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
+ state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
+ state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
+ state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
+ state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
+ state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")
+
+ # q_proj -> to_q
+ state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
+ state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")
+
+ # q_norm, k_norm -> norm_q, norm_k
+ state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
+ state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
+ state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
+ state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]
+
+ state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
+ state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
+ state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")
+
+ # out_proj -> to_out
+ state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
+ state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
+ state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
+ state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")
+
+ # switch norm 2 and norm 3
+ norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
+ norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
+ state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
+ state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
+ state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
+ state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias
+
+ # norm1 -> norm1.norm
+ # default_modulation.1 -> norm1.linear
+ state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"]
+ state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"]
+ state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"]
+ state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"]
+ state_dict.pop(f"blocks.{i}.norm1.weight")
+ state_dict.pop(f"blocks.{i}.norm1.bias")
+ state_dict.pop(f"blocks.{i}.default_modulation.1.weight")
+ state_dict.pop(f"blocks.{i}.default_modulation.1.bias")
+
+ # mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2
+ state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"]
+ state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"]
+ state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"]
+ state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"]
+ state_dict.pop(f"blocks.{i}.mlp.fc1.weight")
+ state_dict.pop(f"blocks.{i}.mlp.fc1.bias")
+ state_dict.pop(f"blocks.{i}.mlp.fc2.weight")
+ state_dict.pop(f"blocks.{i}.mlp.fc2.bias")
+
+ # pooler -> time_extra_emb
+ state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
+ state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
+ state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
+ state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
+ state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
+ state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
+ state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
+ state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
+ state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
+ state_dict.pop("pooler.k_proj.weight")
+ state_dict.pop("pooler.k_proj.bias")
+ state_dict.pop("pooler.q_proj.weight")
+ state_dict.pop("pooler.q_proj.bias")
+ state_dict.pop("pooler.v_proj.weight")
+ state_dict.pop("pooler.v_proj.bias")
+ state_dict.pop("pooler.c_proj.weight")
+ state_dict.pop("pooler.c_proj.bias")
+ state_dict.pop("pooler.positional_embedding")
+
+    # t_embedder -> time_extra_emb.timestep_embedder (`TimestepEmbedding`)
+ state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
+ state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
+ state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
+ state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]
+
+ state_dict.pop("t_embedder.mlp.0.bias")
+ state_dict.pop("t_embedder.mlp.0.weight")
+ state_dict.pop("t_embedder.mlp.2.bias")
+ state_dict.pop("t_embedder.mlp.2.weight")
+
+    # x_embedder -> pos_embed (`PatchEmbed`)
+ state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
+ state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
+ state_dict.pop("x_embedder.proj.weight")
+ state_dict.pop("x_embedder.proj.bias")
+
+ # mlp_t5 -> text_embedder
+ state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"]
+ state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"]
+ state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"]
+ state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"]
+ state_dict.pop("mlp_t5.0.bias")
+ state_dict.pop("mlp_t5.0.weight")
+ state_dict.pop("mlp_t5.2.bias")
+ state_dict.pop("mlp_t5.2.weight")
+
+    # extra_embedder -> time_extra_emb.extra_embedder
+ state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"]
+ state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"]
+ state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"]
+ state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"]
+ state_dict.pop("extra_embedder.0.bias")
+ state_dict.pop("extra_embedder.0.weight")
+ state_dict.pop("extra_embedder.2.bias")
+ state_dict.pop("extra_embedder.2.weight")
+
+    # final_layer.adaLN_modulation.1 -> norm_out.linear
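+    # The original checkpoint stores the final adaLN modulation as (shift, scale), while
+    # diffusers expects (scale, shift) for `norm_out.linear`, so the two halves are swapped.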
+ def swap_scale_shift(weight):
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ return new_weight
+
+ state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.weight"])
+ state_dict["norm_out.linear.bias"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.bias"])
+ state_dict.pop("final_layer.adaLN_modulation.1.weight")
+ state_dict.pop("final_layer.adaLN_modulation.1.bias")
+
+    # final_layer.linear -> proj_out
+ state_dict["proj_out.weight"] = state_dict["final_layer.linear.weight"]
+ state_dict["proj_out.bias"] = state_dict["final_layer.linear.bias"]
+ state_dict.pop("final_layer.linear.weight")
+ state_dict.pop("final_layer.linear.bias")
+
+ # style_embedder
+ if model_config["use_style_cond_and_image_meta_size"]:
+ print(state_dict["style_embedder.weight"])
+ print(state_dict["style_embedder.weight"].shape)
+ state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1]
+ state_dict.pop("style_embedder.weight")
+
+ model.load_state_dict(state_dict)
+
+ from diffusers import HunyuanDiTPipeline
+
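+    # Sanity-check the converted weights by running the full text-to-image pipeline once
+    # (requires a CUDA-capable GPU).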
+ if args.use_style_cond_and_image_meta_size:
+ pipe = HunyuanDiTPipeline.from_pretrained(
+ "Tencent-Hunyuan/HunyuanDiT-Diffusers", transformer=model, torch_dtype=torch.float32
+ )
+ else:
+ pipe = HunyuanDiTPipeline.from_pretrained(
+ "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", transformer=model, torch_dtype=torch.float32
+ )
+ pipe.to("cuda")
+ pipe.to(dtype=torch.float32)
+
+ if args.save:
+ pipe.save_pretrained(args.output_checkpoint_path)
+
+    # NOTE: HunyuanDiT supports both Chinese and English inputs.
+ prompt = "ไธไธชๅฎ่ชๅๅจ้ช้ฉฌ"
+ # prompt = "An astronaut riding a horse"
+ generator = torch.Generator(device="cuda").manual_seed(0)
+ image = pipe(
+ height=1024, width=1024, prompt=prompt, generator=generator, num_inference_steps=25, guidance_scale=5.0
+ ).images[0]
+
+ image.save("img.png")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
+ )
+ parser.add_argument(
+ "--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
+ )
+ parser.add_argument(
+ "--output_checkpoint_path",
+ default=None,
+ type=str,
+ required=False,
+ help="Path to the output converted diffusers pipeline.",
+ )
+ parser.add_argument(
+ "--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file"
+ )
+ parser.add_argument(
+ "--use_style_cond_and_image_meta_size",
+ type=bool,
+ default=False,
+ help="version <= v1.1: True; version >= v1.2: False",
+ )
+
+ args = parser.parse_args()
+ main(args)
diff --git a/src/diffusers/loaders/autoencoder.py b/src/diffusers/loaders/autoencoder.py
deleted file mode 100644
index 36b022a26ec9..000000000000
--- a/src/diffusers/loaders/autoencoder.py
+++ /dev/null
@@ -1,146 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from huggingface_hub.utils import validate_hf_hub_args
-
-from .single_file_utils import (
- create_diffusers_vae_model_from_ldm,
- fetch_ldm_config_and_checkpoint,
-)
-
-
-class FromOriginalVAEMixin:
- """
- Load pretrained AutoencoderKL weights saved in the `.ckpt` or `.safetensors` format into a [`AutoencoderKL`].
- """
-
- @classmethod
- @validate_hf_hub_args
- def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
- r"""
- Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
- `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
-
- Parameters:
- pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
- Can be either:
- - A link to the `.ckpt` file (for example
-                  `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- - A path to a *file* containing all pipeline weights.
- config_file (`str`, *optional*):
- Filepath to the configuration YAML file associated with the model. If not provided it will default to:
- https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
- dtype is automatically derived from the model's weights.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- resume_download:
- Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
- of Diffusers.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to True, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- image_size (`int`, *optional*, defaults to 512):
- The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
- Diffusion v2 base model. Use 768 for Stable Diffusion v2.
- scaling_factor (`float`, *optional*, defaults to 0.18215):
- The component-wise standard deviation of the trained latent space computed using the first batch of the
- training set. This is used to scale the latent space to have unit variance when training the diffusion
- model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
- diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
- = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
- Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
- kwargs (remaining dictionary of keyword arguments, *optional*):
- Can be used to overwrite load and saveable variables (for example the pipeline components of the
- specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
- method. See example below for more information.
-
-        <Tip warning={true}>
-
- Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
- a VAE from SDXL or a Stable Diffusion v2 model or higher.
-
-        </Tip>
-
- Examples:
-
- ```py
- from diffusers import AutoencoderKL
-
- url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
- model = AutoencoderKL.from_single_file(url)
- ```
- """
-
- original_config_file = kwargs.pop("original_config_file", None)
- config_file = kwargs.pop("config_file", None)
- resume_download = kwargs.pop("resume_download", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- token = kwargs.pop("token", None)
- cache_dir = kwargs.pop("cache_dir", None)
- local_files_only = kwargs.pop("local_files_only", None)
- revision = kwargs.pop("revision", None)
- torch_dtype = kwargs.pop("torch_dtype", None)
-
- class_name = cls.__name__
-
- if (config_file is not None) and (original_config_file is not None):
- raise ValueError(
- "You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments."
- )
-
- original_config_file = original_config_file or config_file
- original_config, checkpoint = fetch_ldm_config_and_checkpoint(
- pretrained_model_link_or_path=pretrained_model_link_or_path,
- class_name=class_name,
- original_config_file=original_config_file,
- resume_download=resume_download,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- )
-
- image_size = kwargs.pop("image_size", None)
- scaling_factor = kwargs.pop("scaling_factor", None)
- component = create_diffusers_vae_model_from_ldm(
- class_name,
- original_config,
- checkpoint,
- image_size=image_size,
- scaling_factor=scaling_factor,
- torch_dtype=torch_dtype,
- )
- vae = component["vae"]
- if torch_dtype is not None:
- vae = vae.to(torch_dtype)
-
- return vae
diff --git a/src/diffusers/loaders/controlnet.py b/src/diffusers/loaders/controlnet.py
deleted file mode 100644
index 53b9802d390e..000000000000
--- a/src/diffusers/loaders/controlnet.py
+++ /dev/null
@@ -1,136 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from huggingface_hub.utils import validate_hf_hub_args
-
-from .single_file_utils import (
- create_diffusers_controlnet_model_from_ldm,
- fetch_ldm_config_and_checkpoint,
-)
-
-
-class FromOriginalControlNetMixin:
- """
- Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
- """
-
- @classmethod
- @validate_hf_hub_args
- def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
- r"""
- Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
- `.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
-
- Parameters:
- pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
- Can be either:
- - A link to the `.ckpt` file (for example
-                  `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- - A path to a *file* containing all pipeline weights.
- config_file (`str`, *optional*):
- Filepath to the configuration YAML file associated with the model. If not provided it will default to:
- https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml
- torch_dtype (`str` or `torch.dtype`, *optional*):
- Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
- dtype is automatically derived from the model's weights.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- resume_download:
- Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
- of Diffusers.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to True, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- image_size (`int`, *optional*, defaults to 512):
- The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
- Diffusion v2 base model. Use 768 for Stable Diffusion v2.
- upcast_attention (`bool`, *optional*, defaults to `None`):
- Whether the attention computation should always be upcasted.
- kwargs (remaining dictionary of keyword arguments, *optional*):
- Can be used to overwrite load and saveable variables (for example the pipeline components of the
- specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
- method. See example below for more information.
-
- Examples:
-
- ```py
- from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
-
- url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
- model = ControlNetModel.from_single_file(url)
-
- url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
- pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
- ```
- """
- original_config_file = kwargs.pop("original_config_file", None)
- config_file = kwargs.pop("config_file", None)
- resume_download = kwargs.pop("resume_download", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- token = kwargs.pop("token", None)
- cache_dir = kwargs.pop("cache_dir", None)
- local_files_only = kwargs.pop("local_files_only", None)
- revision = kwargs.pop("revision", None)
- torch_dtype = kwargs.pop("torch_dtype", None)
-
- class_name = cls.__name__
- if (config_file is not None) and (original_config_file is not None):
- raise ValueError(
- "You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments."
- )
-
- original_config_file = config_file or original_config_file
- original_config, checkpoint = fetch_ldm_config_and_checkpoint(
- pretrained_model_link_or_path=pretrained_model_link_or_path,
- class_name=class_name,
- original_config_file=original_config_file,
- resume_download=resume_download,
- force_download=force_download,
- proxies=proxies,
- token=token,
- revision=revision,
- local_files_only=local_files_only,
- cache_dir=cache_dir,
- )
-
- upcast_attention = kwargs.pop("upcast_attention", False)
- image_size = kwargs.pop("image_size", None)
-
- component = create_diffusers_controlnet_model_from_ldm(
- class_name,
- original_config,
- checkpoint,
- upcast_attention=upcast_attention,
- image_size=image_size,
- torch_dtype=torch_dtype,
- )
- controlnet = component["controlnet"]
- if torch_dtype is not None:
- controlnet = controlnet.to(torch_dtype)
-
- return controlnet
diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index d7bf67288c0a..f6e6373ce035 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -555,7 +555,4 @@ def load_module(name, value):
pipe = pipeline_class(**init_kwargs)
- if torch_dtype is not None:
- pipe.to(dtype=torch_dtype)
-
return pipe
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index ff076c82b00b..c58251139c49 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -1808,4 +1808,17 @@ def create_diffusers_t5_model_from_checkpoint(
else:
model.load_state_dict(diffusers_format_checkpoint)
+
+ use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
+ if use_keep_in_fp32_modules:
+ keep_in_fp32_modules = model._keep_in_fp32_modules
+ else:
+ keep_in_fp32_modules = []
+
+ if keep_in_fp32_modules is not None:
+ for name, param in model.named_parameters():
+ if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
+                # `param = param.to(torch.float32)` would only rebind the local variable, so cast the tensor data in place instead.
+ param.data = param.data.to(torch.float32)
+
return model
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index e19b087431a2..2a81f357d48b 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -128,9 +128,9 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_onl
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
- dim_head=attention_head_dim // num_attention_heads,
+ dim_head=attention_head_dim,
heads=num_attention_heads,
- out_dim=attention_head_dim,
+ out_dim=dim,
context_pre_only=context_pre_only,
bias=True,
processor=processor,
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 8784dcda4b6e..9d495695e330 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2775,6 +2775,26 @@ def __call__(
return hidden_states
+class LoRAAttnProcessor:
+ def __init__(self):
+ pass
+
+
+class LoRAAttnProcessor2_0:
+ def __init__(self):
+ pass
+
+
+class LoRAXFormersAttnProcessor:
+ def __init__(self):
+ pass
+
+
+class LoRAAttnAddedKVProcessor:
+ def __init__(self):
+ pass
+
+
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
diff --git a/src/diffusers/models/controlnet_hunyuan.py b/src/diffusers/models/controlnet_hunyuan.py
index c97cdaf913b3..4277d81d1cb9 100644
--- a/src/diffusers/models/controlnet_hunyuan.py
+++ b/src/diffusers/models/controlnet_hunyuan.py
@@ -57,6 +57,7 @@ def __init__(
pooled_projection_dim: int = 1024,
text_len: int = 77,
text_len_t5: int = 256,
+ use_style_cond_and_image_meta_size: bool = True,
):
super().__init__()
self.num_heads = num_attention_heads
@@ -87,6 +88,7 @@ def __init__(
pooled_projection_dim=pooled_projection_dim,
seq_len=text_len_t5,
cross_attention_dim=cross_attention_dim_t5,
+ use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
)
# controlnet_blocks
diff --git a/src/diffusers/models/controlnet_sd3.py b/src/diffusers/models/controlnet_sd3.py
index 629cb661eda5..2b4dd0fa8b72 100644
--- a/src/diffusers/models/controlnet_sd3.py
+++ b/src/diffusers/models/controlnet_sd3.py
@@ -81,7 +81,7 @@ def __init__(
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
- attention_head_dim=self.inner_dim,
+ attention_head_dim=self.config.attention_head_dim,
context_pre_only=False,
)
for i in range(num_layers)
@@ -239,16 +239,16 @@ def _set_gradient_checkpointing(self, module, value=False):
module.gradient_checkpointing = value
@classmethod
- def from_transformer(cls, transformer, num_layers=None, load_weights_from_transformer=True):
+ def from_transformer(cls, transformer, num_layers=12, load_weights_from_transformer=True):
config = transformer.config
config["num_layers"] = num_layers or config.num_layers
controlnet = cls(**config)
if load_weights_from_transformer:
- controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=False)
- controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict(), strict=False)
- controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict(), strict=False)
- controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict())
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
+ controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
+ controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index cb64bc61f3e9..cb6cb065dd32 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -415,9 +415,10 @@ def __init__(
if set_W_to_weight:
# to delete later
+ del self.weight
self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
-
self.weight = self.W
+ del self.W
def forward(self, x):
if self.log:
@@ -717,18 +718,33 @@ def forward(self, x):
class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
- def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
+ def __init__(
+ self,
+ embedding_dim,
+ pooled_projection_dim=1024,
+ seq_len=256,
+ cross_attention_dim=2048,
+ use_style_cond_and_image_meta_size=True,
+ ):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+ self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+
self.pooler = HunyuanDiTAttentionPool(
seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
)
+
# Here we use a default learned embedder layer for future extension.
- self.style_embedder = nn.Embedding(1, embedding_dim)
- extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+ self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
+ if use_style_cond_and_image_meta_size:
+ self.style_embedder = nn.Embedding(1, embedding_dim)
+ extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+ else:
+ extra_in_dim = pooled_projection_dim
+
self.extra_embedder = PixArtAlphaTextProjection(
in_features=extra_in_dim,
hidden_size=embedding_dim * 4,
@@ -743,16 +759,20 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidde
# extra condition1: text
pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024)
- # extra condition2: image meta size embdding
- image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
- image_meta_size = image_meta_size.to(dtype=hidden_dtype)
- image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
+ if self.use_style_cond_and_image_meta_size:
+            # extra condition2: image meta size embedding
+ image_meta_size = self.size_proj(image_meta_size.view(-1))
+ image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+ image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
- # extra condition3: style embedding
- style_embedding = self.style_embedder(style) # (N, embedding_dim)
+ # extra condition3: style embedding
+ style_embedding = self.style_embedder(style) # (N, embedding_dim)
+
+ # Concatenate all extra vectors
+ extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+ else:
+ extra_cond = torch.cat([pooled_projections], dim=1)
- # Concatenate all extra vectors
- extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D]
return conditioning
diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
index d67b35586a60..8313ffd87a50 100644
--- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py
+++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -249,6 +249,8 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
The length of the clip text embedding.
text_len_t5 (`int`, *optional*):
The length of the T5 text embedding.
+ use_style_cond_and_image_meta_size (`bool`, *optional*):
+            Whether or not to use style condition and image meta size. True for version <= 1.1, False for version >= 1.2.
"""
@register_to_config
@@ -270,6 +272,7 @@ def __init__(
pooled_projection_dim: int = 1024,
text_len: int = 77,
text_len_t5: int = 256,
+ use_style_cond_and_image_meta_size: bool = True,
):
super().__init__()
self.out_channels = in_channels * 2 if learn_sigma else in_channels
@@ -301,6 +304,7 @@ def __init__(
pooled_projection_dim=pooled_projection_dim,
seq_len=text_len_t5,
cross_attention_dim=cross_attention_dim_t5,
+ use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
)
# HunyuanDiT Blocks
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index d514a43537d8..1b9126b3b849 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -97,7 +97,7 @@ def __init__(
JointTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
- attention_head_dim=self.inner_dim,
+ attention_head_dim=self.config.attention_head_dim,
context_pre_only=i == num_layers - 1,
)
for i in range(self.config.num_layers)
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index 584014c78c6f..e2657e56901f 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -597,7 +597,9 @@ def from_unet2d(
if not config.get("num_attention_heads"):
config["num_attention_heads"] = config["attention_head_dim"]
- config = FrozenDict(config)
+ expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
+ config = FrozenDict({k: config.get(k) for k in config if k in expected_kwargs or k in optional_kwargs})
+ config["_class_name"] = cls.__name__
model = cls.from_config(config)
if not load_weights:
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index a0daf5ced9fc..1fa324cc912f 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -62,7 +62,7 @@
>>> pipe = pipe.to(device)
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
- >>> init_image = load_image(url).resize((512, 512))
+ >>> init_image = load_image(url).resize((1024, 1024))
>>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index d0253ff474d9..ce90fb09193b 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -358,42 +358,42 @@ def _get_model_file(
)
return model_file
- except RepositoryNotFoundError:
+ except RepositoryNotFoundError as e:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `token` or log in with `huggingface-cli "
"login`."
- )
- except RevisionNotFoundError:
+ ) from e
+ except RevisionNotFoundError as e:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
- )
- except EntryNotFoundError:
+ ) from e
+ except EntryNotFoundError as e:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
- )
- except HTTPError as err:
+ ) from e
+ except HTTPError as e:
raise EnvironmentError(
- f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
- )
- except ValueError:
+ f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{e}"
+ ) from e
+ except ValueError as e:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {weights_name} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
- )
- except EnvironmentError:
+ ) from e
+ except EnvironmentError as e:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {weights_name}"
- )
+ ) from e
# Adapted from
diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py
index d78479247347..0fc185b602a3 100644
--- a/tests/models/autoencoders/test_models_vae.py
+++ b/tests/models/autoencoders/test_models_vae.py
@@ -361,9 +361,10 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
forward_requires_fresh_args = True
def inputs_dict(self, seed=None):
- generator = torch.Generator("cpu")
- if seed is not None:
- generator.manual_seed(0)
+ if seed is None:
+ generator = torch.Generator("cpu").manual_seed(0)
+ else:
+ generator = torch.Generator("cpu").manual_seed(seed)
image = randn_tensor((4, 3, 32, 32), generator=generator, device=torch.device(torch_device))
return {"sample": image, "generator": generator}
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index ac356d4c522d..259b4cc916d3 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -905,11 +905,13 @@ def test_sharded_checkpoints(self):
actual_num_shards = len([file for file in os.listdir(tmp_dir) if file.endswith(".safetensors")])
self.assertTrue(actual_num_shards == expected_num_shards)
- new_model = self.model_class.from_pretrained(tmp_dir)
+ new_model = self.model_class.from_pretrained(tmp_dir).eval()
new_model = new_model.to(torch_device)
torch.manual_seed(0)
+ _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
new_output = new_model(**inputs_dict)
+
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
@require_torch_gpu
@@ -940,6 +942,7 @@ def test_sharded_checkpoints_device_map(self):
new_model = new_model.to(torch_device)
torch.manual_seed(0)
+ _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
new_output = new_model(**inputs_dict)
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
index d8418a38262e..dc68cc3ecdbd 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
@@ -108,7 +108,6 @@ def get_dummy_components(self):
def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- image = image / 2 + 0.5
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py
index 5157cd4ca63c..b2bb7fe827f9 100644
--- a/tests/single_file/single_file_testing_utils.py
+++ b/tests/single_file/single_file_testing_utils.py
@@ -201,6 +201,20 @@ def test_single_file_components_with_diffusers_config_local_files_only(
self._compare_component_configs(pipe, single_file_pipe)
+ def test_single_file_setting_pipeline_dtype_to_fp16(
+ self,
+ single_file_pipe=None,
+ ):
+ single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
+ self.ckpt_path, torch_dtype=torch.float16
+ )
+
+ for component_name, component in single_file_pipe.components.items():
+ if not isinstance(component, torch.nn.Module):
+ continue
+
+ assert component.dtype == torch.float16
+
class SDXLSingleFileTesterMixin:
def _compare_component_configs(self, pipe, single_file_pipe):
@@ -378,3 +392,17 @@ def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_d
max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())
assert max_diff < expected_max_diff
+
+ def test_single_file_setting_pipeline_dtype_to_fp16(
+ self,
+ single_file_pipe=None,
+ ):
+ single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
+ self.ckpt_path, torch_dtype=torch.float16
+ )
+
+ for component_name, component in single_file_pipe.components.items():
+ if not isinstance(component, torch.nn.Module):
+ continue
+
+ assert component.dtype == torch.float16
diff --git a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
index 8e9ac7973609..1af3f5126ff3 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
@@ -180,3 +180,12 @@ def test_single_file_components_with_diffusers_config_local_files_only(self):
local_files_only=True,
)
super()._compare_component_configs(pipe, pipe_single_file)
+
+ def test_single_file_setting_pipeline_dtype_to_fp16(self):
+ controlnet = ControlNetModel.from_pretrained(
+ "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
+ )
+ single_file_pipe = self.pipeline_class.from_single_file(
+ self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
+ )
+ super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)
diff --git a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
index 8c750437f719..1966ecfc207a 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
@@ -181,3 +181,12 @@ def test_single_file_components_with_diffusers_config_local_files_only(self):
local_files_only=True,
)
super()._compare_component_configs(pipe, pipe_single_file)
+
+ def test_single_file_setting_pipeline_dtype_to_fp16(self):
+ controlnet = ControlNetModel.from_pretrained(
+ "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
+ )
+ single_file_pipe = self.pipeline_class.from_single_file(
+ self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
+ )
+ super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)
diff --git a/tests/single_file/test_stable_diffusion_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_single_file.py
index abcf4c11d614..fe066f02cf36 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_single_file.py
@@ -169,3 +169,12 @@ def test_single_file_components_with_diffusers_config_local_files_only(self):
local_files_only=True,
)
super()._compare_component_configs(pipe, pipe_single_file)
+
+ def test_single_file_setting_pipeline_dtype_to_fp16(self):
+ controlnet = ControlNetModel.from_pretrained(
+ "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
+ )
+ single_file_pipe = self.pipeline_class.from_single_file(
+ self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
+ )
+ super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)
diff --git a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
index 43881914d3c0..7f478133c66f 100644
--- a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
@@ -200,3 +200,11 @@ def test_single_file_components_with_original_config_local_files_only(self):
local_files_only=True,
)
self._compare_component_configs(pipe, pipe_single_file)
+
+ def test_single_file_setting_pipeline_dtype_to_fp16(self):
+ adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16)
+
+ single_file_pipe = self.pipeline_class.from_single_file(
+ self.ckpt_path, adapter=adapter, torch_dtype=torch.float16
+ )
+ super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)
diff --git a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
index 6aebc2b01999..a8509510ad80 100644
--- a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
@@ -195,3 +195,12 @@ def test_single_file_components_with_diffusers_config_local_files_only(self):
local_files_only=True,
)
super()._compare_component_configs(pipe, pipe_single_file)
+
+ def test_single_file_setting_pipeline_dtype_to_fp16(self):
+ controlnet = ControlNetModel.from_pretrained(
+ "diffusers/controlnet-depth-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
+ )
+ single_file_pipe = self.pipeline_class.from_single_file(
+ self.ckpt_path, controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
+ )
+ super().test_single_file_setting_pipeline_dtype_to_fp16(single_file_pipe)
diff --git a/utils/fetch_torch_cuda_pipeline_test_matrix.py b/utils/fetch_torch_cuda_pipeline_test_matrix.py
index 744201cb2d48..235d346d9306 100644
--- a/utils/fetch_torch_cuda_pipeline_test_matrix.py
+++ b/utils/fetch_torch_cuda_pipeline_test_matrix.py
@@ -88,7 +88,7 @@ def main():
test_modules.extend(ALWAYS_TEST_PIPELINE_MODULES)
# Get unique modules
- test_modules = list(set(test_modules))
+ test_modules = sorted(set(test_modules))
print(json.dumps(test_modules))
save_path = f"{PATH_TO_REPO}/reports"