From 1d1e1a2888bd65b51f13272de2f709fd91e0beb1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 24 Oct 2024 20:19:09 +0530 Subject: [PATCH 01/14] Some minor updates to the nightly and push workflows (#9759) * move lora integration tests to nightly./ * remove slow marker in the workflow where not needed. --- .github/workflows/push_tests.yml | 6 +++--- tests/lora/test_lora_layers_flux.py | 4 +++- tests/lora/test_lora_layers_sd.py | 2 ++ tests/lora/test_lora_layers_sdxl.py | 1 + 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index f07e6cda0d59..2289d1b5cad1 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -81,7 +81,7 @@ jobs: - name: Environment run: | python utils/print_env.py - - name: Slow PyTorch CUDA checkpoint tests on Ubuntu + - name: PyTorch CUDA checkpoint tests on Ubuntu env: HF_TOKEN: ${{ secrets.HF_TOKEN }} # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms @@ -184,7 +184,7 @@ jobs: run: | python utils/print_env.py - - name: Run slow Flax TPU tests + - name: Run Flax TPU tests env: HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | @@ -232,7 +232,7 @@ jobs: run: | python utils/print_env.py - - name: Run slow ONNXRuntime CUDA tests + - name: Run ONNXRuntime CUDA tests env: HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 3bc46d1e9b13..b58525cc7a6f 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -27,6 +27,7 @@ from diffusers.utils.testing_utils import ( floats_tensor, is_peft_available, + nightly, numpy_cosine_similarity_distance, require_peft_backend, require_torch_gpu, @@ -165,9 +166,10 @@ def test_modify_padding_mode(self): @slow +@nightly @require_torch_gpu @require_peft_backend -# @unittest.skip("We cannot run inference on this model with the current CI hardware") +@unittest.skip("We cannot run inference on this model with the current CI hardware") # TODO (DN6, sayakpaul): move these tests to a beefier GPU class FluxLoRAIntegrationTests(unittest.TestCase): """internal note: The integration slices were obtained on audace. diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py index 50187e50a912..e91b0689b4ce 100644 --- a/tests/lora/test_lora_layers_sd.py +++ b/tests/lora/test_lora_layers_sd.py @@ -34,6 +34,7 @@ from diffusers.utils.import_utils import is_accelerate_available from diffusers.utils.testing_utils import ( load_image, + nightly, numpy_cosine_similarity_distance, require_peft_backend, require_torch_gpu, @@ -207,6 +208,7 @@ def test_integration_move_lora_dora_cpu(self): @slow +@nightly @require_torch_gpu @require_peft_backend class LoraIntegrationTests(unittest.TestCase): diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index 94a44ed8f9ec..30238c74873b 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -113,6 +113,7 @@ def tearDown(self): @slow +@nightly @require_torch_gpu @require_peft_backend class LoraSDXLIntegrationTests(unittest.TestCase): From 435f6b7e47c031f98b8374b1689e1abeb17bfdb6 Mon Sep 17 00:00:00 2001 From: Zhiyang Shen <1003151222@qq.com> Date: Fri, 25 Oct 2024 19:03:35 +0800 Subject: [PATCH 02/14] [Docs] fix docstring typo in SD3 pipeline (#9765) * fix docstring typo in SD3 pipeline * fix docstring typo in SD3 pipeline --- .../stable_diffusion_3/pipeline_stable_diffusion_3.py | 4 ++-- .../stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py | 4 ++-- .../stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 4b9df578bc4a..43cb40e6e733 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -762,8 +762,8 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead - of a plain tuple. + Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of + a plain tuple. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index 794716303394..a07a056ec851 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -800,8 +800,8 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead - of a plain tuple. + Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of + a plain tuple. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index 7401be39d6f9..d3e0ecf9c3a7 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -921,8 +921,8 @@ def __call__( The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead - of a plain tuple. + Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of + a plain tuple. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, From 94643fac8a27345f695500085d78cc8fa01f5fa9 Mon Sep 17 00:00:00 2001 From: Leo Jiang <74156916+leisuzz@users.noreply.github.com> Date: Fri, 25 Oct 2024 08:35:19 -0600 Subject: [PATCH 03/14] [bugfix] bugfix for npu free memory (#9640) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Improve NPU performance * Improve NPU performance * Improve NPU performance * Improve NPU performance * [bugfix] bugfix for npu free memory * [bugfix] bugfix for npu free memory * [bugfix] bugfix for npu free memory --------- Co-authored-by: 蒋硕 Co-authored-by: Sayak Paul --- src/diffusers/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 9c898ad141ee..0e0d0ce5b568 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -284,7 +284,7 @@ def free_memory(): elif torch.backends.mps.is_available(): torch.mps.empty_cache() elif is_torch_npu_available(): - torch_npu.empty_cache() + torch_npu.npu.empty_cache() # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 From df073ba1373bf261948d88c3182e27842934e47e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 26 Oct 2024 00:07:57 +0900 Subject: [PATCH 04/14] [research_projects] add flux training script with quantization (#9754) * add flux training script with quantization * remove exclamation --- .../flux_lora_quantization/README.md | 166 +++ .../flux_lora_quantization/accelerate.yaml | 17 + .../compute_embeddings.py | 107 ++ .../flux_lora_quantization/ds2.yaml | 23 + .../train_dreambooth_lora_flux_miniature.py | 1183 +++++++++++++++++ 5 files changed, 1496 insertions(+) create mode 100644 examples/research_projects/flux_lora_quantization/README.md create mode 100644 examples/research_projects/flux_lora_quantization/accelerate.yaml create mode 100644 examples/research_projects/flux_lora_quantization/compute_embeddings.py create mode 100644 examples/research_projects/flux_lora_quantization/ds2.yaml create mode 100644 examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py diff --git a/examples/research_projects/flux_lora_quantization/README.md b/examples/research_projects/flux_lora_quantization/README.md new file mode 100644 index 000000000000..ffec85550e51 --- /dev/null +++ b/examples/research_projects/flux_lora_quantization/README.md @@ -0,0 +1,166 @@ +## LoRA fine-tuning Flux.1 Dev with quantization + +> [!NOTE] +> This example is educational in nature and fixes some arguments to keep things simple. It should act as a reference to build things further. + +This example shows how to fine-tune [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with LoRA and quantization. We show this by using the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset. Steps below summarize the workflow: + +* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file. +* `train_dreambooth_lora_flux_miniature.py` takes care of training: + * Since we already precomputed the text embeddings, we don't load the text encoders. + * We load the VAE and use it to precompute the image latents and we then delete it. + * Load the Flux transformer, quantize it with the [NF4 datatype](https://arxiv.org/abs/2305.14314) through `bitsandbytes`, prepare it for 4bit training. + * Add LoRA adapter layers to it and then ensure they are kept in FP32 precision. + * Train! + +To run training in a memory-optimized manner, we additionally use: + +* 8Bit Adam +* Gradient checkpointing + +We have tested the scripts on a 24GB 4090. It works on a free-tier Colab Notebook, too, but it's extremely slow. + +## Training + +Ensure you have installed the required libraries: + +```bash +pip install -U transformers accelerate bitsandbytes peft datasets +pip install git+https://github.com/huggingface/diffusers -U +``` + +Now, compute the text embeddings: + +```bash +python compute_embeddings.py +``` + +It should create a file named `embeddings.parquet`. We're then ready to launch training. First, authenticate so that you can access the Flux.1 Dev model: + +```bash +huggingface-cli +``` + +Then launch: + +```bash +accelerate launch --config_file=accelerate.yaml \ + train_dreambooth_lora_flux_miniature.py \ + --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \ + --data_df_path="embeddings.parquet" \ + --output_dir="yarn_art_lora_flux_nf4" \ + --mixed_precision="fp16" \ + --use_8bit_adam \ + --weighting_scheme="none" \ + --resolution=1024 \ + --train_batch_size=1 \ + --repeats=1 \ + --learning_rate=1e-4 \ + --guidance_scale=1 \ + --report_to="wandb" \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --cache_latents \ + --rank=4 \ + --max_train_steps=700 \ + --seed="0" +``` + +We can direcly pass a quantized checkpoint path, too: + +```diff ++ --quantized_model_path="hf-internal-testing/flux.1-dev-nf4-pkg" +``` + +Depending on the machine, training time will vary but for our case, it was 1.5 hours. It maybe possible to speed this up by using `torch.bfloat16`. + +We support training with the DeepSpeed Zero2 optimizer, too. To use it, first install DeepSpeed: + +```bash +pip install -Uq deepspeed +``` + +And then launch: + +```bash +accelerate launch --config_file=ds2.yaml \ + train_dreambooth_lora_flux_miniature.py \ + --pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \ + --data_df_path="embeddings.parquet" \ + --output_dir="yarn_art_lora_flux_nf4" \ + --mixed_precision="no" \ + --use_8bit_adam \ + --weighting_scheme="none" \ + --resolution=1024 \ + --train_batch_size=1 \ + --repeats=1 \ + --learning_rate=1e-4 \ + --guidance_scale=1 \ + --report_to="wandb" \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --cache_latents \ + --rank=4 \ + --max_train_steps=700 \ + --seed="0" +``` + +## Inference + +When loading the LoRA params (that were obtained on a quantized base model) and merging them into the base model, it is recommended to first dequantize the base model, merge the LoRA params into it, and then quantize the model again. This is because merging into 4bit quantized models can lead to some rounding errors. Below, we provide an end-to-end example: + +1. First, load the original model and merge the LoRA params into it: + +```py +from diffusers import FluxPipeline +import torch + +ckpt_id = "black-forest-labs/FLUX.1-dev" +pipeline = FluxPipeline.from_pretrained( + ckpt_id, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16 +) +pipeline.load_lora_weights("yarn_art_lora_flux_nf4", weight_name="pytorch_lora_weights.safetensors") +pipeline.fuse_lora() +pipeline.unload_lora_weights() + +pipeline.transformer.save_pretrained("fused_transformer") +``` + +2. Quantize the model and run inference + +```py +from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel, BitsAndBytesConfig +import torch + +ckpt_id = "black-forest-labs/FLUX.1-dev" +bnb_4bit_compute_dtype = torch.float16 +nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, +) +transformer = FluxTransformer2DModel.from_pretrained( + "fused_transformer", + quantization_config=nf4_config, + torch_dtype=bnb_4bit_compute_dtype, +) +pipeline = AutoPipelineForText2Image.from_pretrained( + ckpt_id, transformer=transformer, torch_dtype=bnb_4bit_compute_dtype +) +pipeline.enable_model_cpu_offload() + +image = pipeline( + "a puppy in a pond, yarn art style", num_inference_steps=28, guidance_scale=3.5, height=768 +).images[0] +image.save("yarn_merged.png") +``` + +| Dequantize, merge, quantize | Merging directly into quantized model | +|-------|-------| +| ![Image A](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/merged.png) | ![Image B](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/unmerged.png) | + +As we can notice the first column result follows the style more closely. \ No newline at end of file diff --git a/examples/research_projects/flux_lora_quantization/accelerate.yaml b/examples/research_projects/flux_lora_quantization/accelerate.yaml new file mode 100644 index 000000000000..309e13cc140a --- /dev/null +++ b/examples/research_projects/flux_lora_quantization/accelerate.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: NO +downcast_bf16: 'no' +enable_cpu_affinity: true +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/research_projects/flux_lora_quantization/compute_embeddings.py b/examples/research_projects/flux_lora_quantization/compute_embeddings.py new file mode 100644 index 000000000000..8e93af961e65 --- /dev/null +++ b/examples/research_projects/flux_lora_quantization/compute_embeddings.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import pandas as pd +import torch +from datasets import load_dataset +from huggingface_hub.utils import insecure_hashlib +from tqdm.auto import tqdm +from transformers import T5EncoderModel + +from diffusers import FluxPipeline + + +MAX_SEQ_LENGTH = 77 +OUTPUT_PATH = "embeddings.parquet" + + +def generate_image_hash(image): + return insecure_hashlib.sha256(image.tobytes()).hexdigest() + + +def load_flux_dev_pipeline(): + id = "black-forest-labs/FLUX.1-dev" + text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto") + pipeline = FluxPipeline.from_pretrained( + id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map="balanced" + ) + return pipeline + + +@torch.no_grad() +def compute_embeddings(pipeline, prompts, max_sequence_length): + all_prompt_embeds = [] + all_pooled_prompt_embeds = [] + all_text_ids = [] + for prompt in tqdm(prompts, desc="Encoding prompts."): + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=max_sequence_length) + all_prompt_embeds.append(prompt_embeds) + all_pooled_prompt_embeds.append(pooled_prompt_embeds) + all_text_ids.append(text_ids) + + max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 + print(f"Max memory allocated: {max_memory:.3f} GB") + return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids + + +def run(args): + dataset = load_dataset("Norod78/Yarn-art-style", split="train") + image_prompts = {generate_image_hash(sample["image"]): sample["text"] for sample in dataset} + all_prompts = list(image_prompts.values()) + print(f"{len(all_prompts)=}") + + pipeline = load_flux_dev_pipeline() + all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = compute_embeddings( + pipeline, all_prompts, args.max_sequence_length + ) + + data = [] + for i, (image_hash, _) in enumerate(image_prompts.items()): + data.append((image_hash, all_prompt_embeds[i], all_pooled_prompt_embeds[i], all_text_ids[i])) + print(f"{len(data)=}") + + # Create a DataFrame + embedding_cols = ["prompt_embeds", "pooled_prompt_embeds", "text_ids"] + df = pd.DataFrame(data, columns=["image_hash"] + embedding_cols) + print(f"{len(df)=}") + + # Convert embedding lists to arrays (for proper storage in parquet) + for col in embedding_cols: + df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist()) + + # Save the dataframe to a parquet file + df.to_parquet(args.output_path) + print(f"Data successfully serialized to {args.output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--max_sequence_length", + type=int, + default=MAX_SEQ_LENGTH, + help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.", + ) + parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.") + args = parser.parse_args() + + run(args) diff --git a/examples/research_projects/flux_lora_quantization/ds2.yaml b/examples/research_projects/flux_lora_quantization/ds2.yaml new file mode 100644 index 000000000000..beed28fd90ab --- /dev/null +++ b/examples/research_projects/flux_lora_quantization/ds2.yaml @@ -0,0 +1,23 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + gradient_clipping: 1.0 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: 'no' +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py new file mode 100644 index 000000000000..fd2b5568d6d8 --- /dev/null +++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py @@ -0,0 +1,1183 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator, DistributedType +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from peft import LoraConfig, prepare_model_for_kbit_training, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm + +import diffusers +from diffusers import ( + AutoencoderKL, + BitsAndBytesConfig, + FlowMatchEulerDiscreteScheduler, + FluxPipeline, + FluxTransformer2DModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + cast_training_params, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, + free_memory, +) +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + pass + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.31.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + base_model: str = None, + instance_prompt=None, + repo_folder=None, + quantization_config=None, +): + widget_dict = [] + + model_description = f""" +# Flux DreamBooth LoRA - {repo_id} + + + +## Model description + +These are {repo_id} DreamBooth LoRA weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md). + +Was LoRA for the text encoder enabled? False. + +Quantization config: + +```yaml +{quantization_config} +``` + +## Trigger words + +You should use `{instance_prompt}` to trigger the image generation. + +## Download model + +[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab. + +For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) + +## Usage + +TODO + +## License + +Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md). +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="other", + base_model=base_model, + prompt=instance_prompt, + model_description=model_description, + widget=widget_dict, + ) + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "lora", + "flux", + "flux-diffusers", + "template:sd-lora", + ] + + model_card = populate_model_card(model_card, tags=tags) + model_card.save(os.path.join(repo_folder, "README.md")) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--quantized_model_path", + type=str, + default=None, + help="Path to the quantized model.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--data_df_path", + type=str, + default=None, + help=("Path to the parquet file serialized with compute_embeddings.py."), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--max_sequence_length", + type=int, + default=77, + help="Used for reading the embeddings. Needs to be the same as used during `compute_embeddings.py`.", + ) + + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + + parser.add_argument( + "--output_dir", + type=str, + default="flux-dreambooth-lora-nf4", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--guidance_scale", + type=float, + default=3.5, + help="the FLUX.1 dev variant is a guidance distilled model", + ) + + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="none", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], + help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodigy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + + parser.add_argument( + "--cache_latents", + action="store_true", + default=False, + help="Cache the VAE latents", + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + + +class DreamBoothDataset(Dataset): + def __init__( + self, + data_df_path, + dataset_name, + size=1024, + max_sequence_length=77, + center_crop=False, + ): + # Logistics + self.size = size + self.center_crop = center_crop + self.max_sequence_length = max_sequence_length + + self.data_df_path = Path(data_df_path) + if not self.data_df_path.exists(): + raise ValueError("`data_df_path` doesn't exists.") + + # Load images. + dataset = load_dataset(dataset_name, split="train") + instance_images = [sample["image"] for sample in dataset] + image_hashes = [self.generate_image_hash(image) for image in instance_images] + self.instance_images = instance_images + self.image_hashes = image_hashes + + # Image transformations + self.pixel_values = self.apply_image_transformations( + instance_images=instance_images, size=size, center_crop=center_crop + ) + + # Map hashes to embeddings. + self.data_dict = self.map_image_hash_embedding(data_df_path=data_df_path) + + self.num_instance_images = len(instance_images) + self._length = self.num_instance_images + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.pixel_values[index % self.num_instance_images] + image_hash = self.image_hashes[index % self.num_instance_images] + prompt_embeds, pooled_prompt_embeds, text_ids = self.data_dict[image_hash] + example["instance_images"] = instance_image + example["prompt_embeds"] = prompt_embeds + example["pooled_prompt_embeds"] = pooled_prompt_embeds + example["text_ids"] = text_ids + return example + + def apply_image_transformations(self, instance_images, size, center_crop): + pixel_values = [] + + train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size) + train_flip = transforms.RandomHorizontalFlip(p=1.0) + train_transforms = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + for image in instance_images: + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + image = train_resize(image) + if args.random_flip and random.random() < 0.5: + # flip + image = train_flip(image) + if args.center_crop: + y1 = max(0, int(round((image.height - args.resolution) / 2.0))) + x1 = max(0, int(round((image.width - args.resolution) / 2.0))) + image = train_crop(image) + else: + y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) + image = crop(image, y1, x1, h, w) + image = train_transforms(image) + pixel_values.append(image) + + return pixel_values + + def convert_to_torch_tensor(self, embeddings: list): + prompt_embeds = embeddings[0] + pooled_prompt_embeds = embeddings[1] + text_ids = embeddings[2] + prompt_embeds = np.array(prompt_embeds).reshape(self.max_sequence_length, 4096) + pooled_prompt_embeds = np.array(pooled_prompt_embeds).reshape(768) + text_ids = np.array(text_ids).reshape(77, 3) + return torch.from_numpy(prompt_embeds), torch.from_numpy(pooled_prompt_embeds), torch.from_numpy(text_ids) + + def map_image_hash_embedding(self, data_df_path): + hashes_df = pd.read_parquet(data_df_path) + data_dict = {} + for i, row in hashes_df.iterrows(): + embeddings = [row["prompt_embeds"], row["pooled_prompt_embeds"], row["text_ids"]] + prompt_embeds, pooled_prompt_embeds, text_ids = self.convert_to_torch_tensor(embeddings=embeddings) + data_dict.update({row["image_hash"]: (prompt_embeds, pooled_prompt_embeds, text_ids)}) + return data_dict + + def generate_image_hash(self, image): + return insecure_hashlib.sha256(image.tobytes()).hexdigest() + + +def collate_fn(examples): + pixel_values = [example["instance_images"] for example in examples] + prompt_embeds = [example["prompt_embeds"] for example in examples] + pooled_prompt_embeds = [example["pooled_prompt_embeds"] for example in examples] + text_ids = [example["text_ids"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + prompt_embeds = torch.stack(prompt_embeds) + pooled_prompt_embeds = torch.stack(pooled_prompt_embeds) + text_ids = torch.stack(text_ids)[0] # just 2D tensor + + batch = { + "pixel_values": pixel_values, + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, + "text_ids": text_ids, + } + return batch + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + ).repo_id + + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + bnb_4bit_compute_dtype = torch.float32 + if args.mixed_precision == "fp16": + bnb_4bit_compute_dtype = torch.float16 + elif args.mixed_precision == "bf16": + bnb_4bit_compute_dtype = torch.bfloat16 + if args.quantized_model_path is not None: + transformer = FluxTransformer2DModel.from_pretrained( + args.quantized_model_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + torch_dtype=bnb_4bit_compute_dtype, + ) + else: + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + ) + transformer = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + quantization_config=nf4_config, + torch_dtype=bnb_4bit_compute_dtype, + ) + transformer = prepare_model_for_kbit_training(transformer, use_gradient_checkpointing=False) + + # We only train the additional adapter LoRA layers + transformer.requires_grad_(False) + vae.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + vae.to(accelerator.device, dtype=weight_dtype) + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + + # now we will add new LoRA weights to the attention layers + transformer_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + transformer.add_adapter(transformer_lora_config) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + + for model in models: + if isinstance(unwrap_model(model), type(unwrap_model(transformer))): + model = unwrap_model(model) + transformer_lora_layers_to_save = get_peft_model_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + if weights: + weights.pop() + + FluxPipeline.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + text_encoder_lora_layers=None, + ) + + def load_model_hook(models, input_dir): + transformer_ = None + + if not accelerator.distributed_type == DistributedType.DEEPSPEED: + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(transformer))): + transformer_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + else: + if args.quantized_model_path is not None: + transformer_ = FluxTransformer2DModel.from_pretrained( + args.quantized_model_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + torch_dtype=bnb_4bit_compute_dtype, + ) + else: + nf4_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + ) + transformer_ = FluxTransformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + quantization_config=nf4_config, + torch_dtype=bnb_4bit_compute_dtype, + ) + transformer_ = prepare_model_for_kbit_training(transformer_, use_gradient_checkpointing=False) + transformer_.add_adapter(transformer_lora_config) + + lora_state_dict = FluxPipeline.lora_state_dict(input_dir) + + transformer_state_dict = { + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") + } + transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) + incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Make sure the trainable params are in float32. This is again needed since the base models + # are in `weight_dtype`. More details: + # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804 + if args.mixed_precision == "fp16": + models = [transformer_] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [transformer] + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(models, dtype=torch.float32) + + transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) + + # Optimization parameters + transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate} + params_to_optimize = [transformer_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warning( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warning( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warning( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + data_df_path=args.data_df_path, + dataset_name="Norod78/Yarn-art-style", + size=args.resolution, + max_sequence_length=args.max_sequence_length, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + ) + + vae_config_shift_factor = vae.config.shift_factor + vae_config_scaling_factor = vae.config.scaling_factor + vae_config_block_out_channels = vae.config.block_out_channels + if args.cache_latents: + latents_cache = [] + for batch in tqdm(train_dataloader, desc="Caching latents"): + with torch.no_grad(): + batch["pixel_values"] = batch["pixel_values"].to( + accelerator.device, non_blocking=True, dtype=weight_dtype + ) + latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + + del vae + free_memory() + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_name = "dreambooth-flux-dev-lora-nf4" + accelerator.init_trackers(tracker_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + for epoch in range(first_epoch, args.num_train_epochs): + transformer.train() + + for step, batch in enumerate(train_dataloader): + models_to_accumulate = [transformer] + with accelerator.accumulate(models_to_accumulate): + # Convert images to latent space + if args.cache_latents: + model_input = latents_cache[step].sample() + else: + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor + model_input = model_input.to(dtype=weight_dtype) + + vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) + + latent_image_ids = FluxPipeline._prepare_latent_image_ids( + model_input.shape[0], + model_input.shape[2], + model_input.shape[3], + accelerator.device, + weight_dtype, + ) + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + packed_noisy_model_input = FluxPipeline._pack_latents( + noisy_model_input, + batch_size=model_input.shape[0], + num_channels_latents=model_input.shape[1], + height=model_input.shape[2], + width=model_input.shape[3], + ) + + # handle guidance + if transformer.config.guidance_embeds: + guidance = torch.tensor([args.guidance_scale], device=accelerator.device) + guidance = guidance.expand(model_input.shape[0]) + else: + guidance = None + + # Predict the noise + prompt_embeds = batch["prompt_embeds"].to(device=accelerator.device, dtype=weight_dtype) + pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(device=accelerator.device, dtype=weight_dtype) + text_ids = batch["text_ids"].to(device=accelerator.device, dtype=weight_dtype) + model_pred = transformer( + hidden_states=packed_noisy_model_input, + # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing) + timestep=timesteps / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + return_dict=False, + )[0] + model_pred = FluxPipeline._unpack_latents( + model_pred, + height=int(model_input.shape[2] * vae_scale_factor / 2), + width=int(model_input.shape[3] * vae_scale_factor / 2), + vae_scale_factor=vae_scale_factor, + ) + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + target = noise - model_input + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + accelerator.backward(loss) + + if accelerator.sync_gradients: + params_to_clip = transformer.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + transformer = unwrap_model(transformer) + transformer_lora_layers = get_peft_model_state_dict(transformer) + + FluxPipeline.save_lora_weights( + save_directory=args.output_dir, + transformer_lora_layers=transformer_lora_layers, + text_encoder_lora_layers=None, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + base_model=args.pretrained_model_name_or_path, + instance_prompt=None, + repo_folder=args.output_dir, + quantization_config=transformer.config["quantization_config"], + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) From 52d4449810c8e13eb22b57e706e0e03806247da2 Mon Sep 17 00:00:00 2001 From: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com> Date: Fri, 25 Oct 2024 17:24:58 +0200 Subject: [PATCH 05/14] Add a doc for AWS Neuron in Diffusers (#9766) * start draft * add doc * Update docs/source/en/optimization/neuron.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/optimization/neuron.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/optimization/neuron.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/optimization/neuron.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/optimization/neuron.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/optimization/neuron.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/optimization/neuron.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * bref intro of ON * Update docs/source/en/optimization/neuron.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/_toctree.yml | 2 + docs/source/en/optimization/neuron.md | 61 +++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 docs/source/en/optimization/neuron.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 58218c0272bd..87ff9b1fb81a 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -188,6 +188,8 @@ title: Metal Performance Shaders (MPS) - local: optimization/habana title: Habana Gaudi + - local: optimization/neuron + title: AWS Neuron title: Optimized hardware title: Accelerate inference and reduce memory - sections: diff --git a/docs/source/en/optimization/neuron.md b/docs/source/en/optimization/neuron.md new file mode 100644 index 000000000000..b10050e64d7f --- /dev/null +++ b/docs/source/en/optimization/neuron.md @@ -0,0 +1,61 @@ + + +# AWS Neuron + +Diffusers functionalities are available on [AWS Inf2 instances](https://aws.amazon.com/ec2/instance-types/inf2/), which are EC2 instances powered by [Neuron machine learning accelerators](https://aws.amazon.com/machine-learning/inferentia/). These instances aim to provide better compute performance (higher throughput, lower latency) with good cost-efficiency, making them good candidates for AWS users to deploy diffusion models to production. + +[Optimum Neuron](https://huggingface.co/docs/optimum-neuron/en/index) is the interface between Hugging Face libraries and AWS Accelerators, including AWS [Trainium](https://aws.amazon.com/machine-learning/trainium/) and AWS [Inferentia](https://aws.amazon.com/machine-learning/inferentia/). It supports many of the features in Diffusers with similar APIs, so it is easier to learn if you're already familiar with Diffusers. Once you have created an AWS Inf2 instance, install Optimum Neuron. + +```bash +python -m pip install --upgrade-strategy eager optimum[neuronx] +``` + + + +We provide pre-built [Hugging Face Neuron Deep Learning AMI](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2) (DLAMI) and Optimum Neuron containers for Amazon SageMaker. It's recommended to correctly set up your environment. + + + +The example below demonstrates how to generate images with the Stable Diffusion XL model on an inf2.8xlarge instance (you can switch to cheaper inf2.xlarge instances once the model is compiled). To generate some images, use the [`~optimum.neuron.NeuronStableDiffusionXLPipeline`] class, which is similar to the [`StableDiffusionXLPipeline`] class in Diffusers. + +Unlike Diffusers, you need to compile models in the pipeline to the Neuron format, `.neuron`. Launch the following command to export the model to the `.neuron` format. + +```bash +optimum-cli export neuron --model stabilityai/stable-diffusion-xl-base-1.0 \ + --batch_size 1 \ + --height 1024 `# height in pixels of generated image, eg. 768, 1024` \ + --width 1024 `# width in pixels of generated image, eg. 768, 1024` \ + --num_images_per_prompt 1 `# number of images to generate per prompt, defaults to 1` \ + --auto_cast matmul `# cast only matrix multiplication operations` \ + --auto_cast_type bf16 `# cast operations from FP32 to BF16` \ + sd_neuron_xl/ +``` + +Now generate some images with the pre-compiled SDXL model. + +```python +>>> from optimum.neuron import NeuronStableDiffusionXLPipeline + +>>> stable_diffusion_xl = NeuronStableDiffusionXLPipeline.from_pretrained("sd_neuron_xl/") +>>> prompt = "a pig with wings flying in floating US dollar banknotes in the air, skyscrapers behind, warm color palette, muted colors, detailed, 8k" +>>> image = stable_diffusion_xl(prompt).images[0] +``` + +peggy generated by sdxl on inf2 + +Feel free to check out more guides and examples on different use cases from the Optimum Neuron [documentation](https://huggingface.co/docs/optimum-neuron/en/inference_tutorials/stable_diffusion#generate-images-with-stable-diffusion-models-on-aws-inferentia)! From 73b59f5203b5df71175dfd71f613b9bd380b4531 Mon Sep 17 00:00:00 2001 From: Ina <1224084650@qq.com> Date: Sat, 26 Oct 2024 05:01:51 +0800 Subject: [PATCH 06/14] [refactor] enhance readability of flux related pipelines (#9711) * flux pipline: readability enhancement. --- .../train_dreambooth_lora_flux_advanced.py | 8 ++--- examples/controlnet/train_controlnet_flux.py | 4 +-- examples/dreambooth/train_dreambooth_flux.py | 10 +++--- .../dreambooth/train_dreambooth_lora_flux.py | 10 +++--- src/diffusers/pipelines/flux/pipeline_flux.py | 26 +++++++------- .../flux/pipeline_flux_controlnet.py | 26 +++++++------- ...pipeline_flux_controlnet_image_to_image.py | 28 ++++++++------- .../pipeline_flux_controlnet_inpainting.py | 34 +++++++++++-------- .../pipelines/flux/pipeline_flux_img2img.py | 28 ++++++++------- .../pipelines/flux/pipeline_flux_inpaint.py | 32 +++++++++-------- 10 files changed, 110 insertions(+), 96 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index e3e46ead8ee3..ccc390ab7b2c 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -2198,8 +2198,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], - model_input.shape[2], - model_input.shape[3], + model_input.shape[2] // 2, + model_input.shape[3] // 2, accelerator.device, weight_dtype, ) @@ -2253,8 +2253,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): )[0] model_pred = FluxPipeline._unpack_latents( model_pred, - height=int(model_input.shape[2] * vae_scale_factor / 2), - width=int(model_input.shape[3] * vae_scale_factor / 2), + height=model_input.shape[2] * vae_scale_factor, + width=model_input.shape[3] * vae_scale_factor, vae_scale_factor=vae_scale_factor, ) diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py index ca822b16eae2..2958a9e5f28f 100644 --- a/examples/controlnet/train_controlnet_flux.py +++ b/examples/controlnet/train_controlnet_flux.py @@ -1256,8 +1256,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): latent_image_ids = FluxControlNetPipeline._prepare_latent_image_ids( batch_size=pixel_latents_tmp.shape[0], - height=pixel_latents_tmp.shape[2], - width=pixel_latents_tmp.shape[3], + height=pixel_latents_tmp.shape[2] // 2, + width=pixel_latents_tmp.shape[3] // 2, device=pixel_values.device, dtype=pixel_values.dtype, ) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index db4788281cf2..add266d3ac0c 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1540,12 +1540,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor model_input = model_input.to(dtype=weight_dtype) - vae_scale_factor = 2 ** (len(vae.config.block_out_channels)) + vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], - model_input.shape[2], - model_input.shape[3], + model_input.shape[2] // 2, + model_input.shape[3] // 2, accelerator.device, weight_dtype, ) @@ -1601,8 +1601,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042 model_pred = FluxPipeline._unpack_latents( model_pred, - height=int(model_input.shape[2] * vae_scale_factor / 2), - width=int(model_input.shape[3] * vae_scale_factor / 2), + height=model_input.shape[2] * vae_scale_factor, + width=model_input.shape[3] * vae_scale_factor, vae_scale_factor=vae_scale_factor, ) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index b09e5b38b2b1..fa4db10f4f7b 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1645,12 +1645,12 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) - vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) + vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1) latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], - model_input.shape[2], - model_input.shape[3], + model_input.shape[2] // 2, + model_input.shape[3] // 2, accelerator.device, weight_dtype, ) @@ -1704,8 +1704,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): )[0] model_pred = FluxPipeline._unpack_latents( model_pred, - height=int(model_input.shape[2] * vae_scale_factor / 2), - width=int(model_input.shape[3] * vae_scale_factor / 2), + height=model_input.shape[2] * vae_scale_factor, + width=model_input.shape[3] * vae_scale_factor, vae_scale_factor=vae_scale_factor, ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 8278365e9467..040d935f1b88 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -195,13 +195,13 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - self.default_sample_size = 64 + self.default_sample_size = 128 def _get_t5_prompt_embeds( self, @@ -386,8 +386,10 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -425,9 +427,9 @@ def check_inputs( @staticmethod def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -452,10 +454,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor): height = height // vae_scale_factor width = width // vae_scale_factor - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) return latents @@ -499,8 +501,8 @@ def prepare_latents( generator, latents=None, ): - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor shape = (batch_size, num_channels_latents, height, width) @@ -517,7 +519,7 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) return latents, latent_image_ids diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 5136c4200147..9f33e26013d5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -216,13 +216,13 @@ def __init__( controlnet=controlnet, ) self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - self.default_sample_size = 64 + self.default_sample_size = 128 def _get_t5_prompt_embeds( self, @@ -410,8 +410,10 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -450,9 +452,9 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -479,10 +481,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor): height = height // vae_scale_factor width = width // vae_scale_factor - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) return latents @@ -498,8 +500,8 @@ def prepare_latents( generator, latents=None, ): - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor shape = (batch_size, num_channels_latents, height, width) @@ -516,7 +518,7 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) return latents, latent_image_ids diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 8d636feeae05..810c970ab715 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -228,13 +228,13 @@ def __init__( controlnet=controlnet, ) self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - self.default_sample_size = 64 + self.default_sample_size = 128 # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( @@ -453,8 +453,10 @@ def check_inputs( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -493,9 +495,9 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -522,10 +524,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor): height = height // vae_scale_factor width = width // vae_scale_factor - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) return latents @@ -549,11 +551,11 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) if latents is not None: return latents.to(device=device, dtype=dtype), latent_image_ids @@ -852,7 +854,7 @@ def __call__( control_mode = control_mode.reshape([-1, 1]) sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 46784f2d46d1..3ca2de633fcf 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -231,7 +231,7 @@ def __init__( ) self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( @@ -244,7 +244,7 @@ def __init__( self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - self.default_sample_size = 64 + self.default_sample_size = 128 # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( @@ -467,8 +467,10 @@ def check_inputs( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -520,9 +522,9 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -549,10 +551,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor): height = height // vae_scale_factor width = width // vae_scale_factor - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) return latents @@ -576,11 +578,11 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) image = image.to(device=device, dtype=dtype) image_latents = self._encode_vae_image(image=image, generator=generator) @@ -622,8 +624,8 @@ def prepare_mask_latents( device, generator, ): - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision @@ -996,7 +998,9 @@ def __call__( # 6. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = (int(global_height) // self.vae_scale_factor) * (int(global_width) // self.vae_scale_factor) + image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * ( + int(global_width) // self.vae_scale_factor // 2 + ) mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index 112260003ef5..47f9f268ee9d 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -212,13 +212,13 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - self.default_sample_size = 64 + self.default_sample_size = 128 # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( @@ -437,8 +437,10 @@ def check_inputs( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -477,9 +479,9 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -506,10 +508,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor): height = height // vae_scale_factor width = width // vae_scale_factor - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) return latents @@ -532,11 +534,11 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) if latents is not None: return latents.to(device=device, dtype=dtype), latent_image_ids @@ -736,7 +738,7 @@ def __call__( # 4.Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index ae348c0f6421..766f9864839e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -209,7 +209,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( @@ -222,7 +222,7 @@ def __init__( self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) - self.default_sample_size = 64 + self.default_sample_size = 128 # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( @@ -445,8 +445,10 @@ def check_inputs( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -498,9 +500,9 @@ def check_inputs( @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height // 2, width // 2, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape @@ -527,10 +529,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor): height = height // vae_scale_factor width = width // vae_scale_factor - latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) - latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) return latents @@ -553,11 +555,11 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) image = image.to(device=device, dtype=dtype) image_latents = self._encode_vae_image(image=image, generator=generator) @@ -598,8 +600,8 @@ def prepare_mask_latents( device, generator, ): - height = 2 * (int(height) // self.vae_scale_factor) - width = 2 * (int(width) // self.vae_scale_factor) + height = int(height) // self.vae_scale_factor + width = int(width) // self.vae_scale_factor # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision @@ -866,7 +868,7 @@ def __call__( # 4.Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) - image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, self.scheduler.config.base_image_seq_len, From 298ab6eb01f3ef475c15218ea87de1494e1250aa Mon Sep 17 00:00:00 2001 From: SahilCarterr <110806554+SahilCarterr@users.noreply.github.com> Date: Sat, 26 Oct 2024 03:20:55 +0530 Subject: [PATCH 07/14] Added Support of Xlabs controlnet to FluxControlNetInpaintPipeline (#9770) * added xlabs support --- .../pipeline_flux_controlnet_inpainting.py | 61 +++++++++++-------- 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 3ca2de633fcf..1f5f83561f1c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -932,19 +932,22 @@ def __call__( ) height, width = control_image.shape[-2:] - # vae encode - control_image = self.vae.encode(control_image).latent_dist.sample() - control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - # pack - height_control_image, width_control_image = control_image.shape[2:] - control_image = self._pack_latents( - control_image, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True + if self.controlnet.input_hint_block is None: + # vae encode + control_image = self.vae.encode(control_image).latent_dist.sample() + control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) # set control mode if control_mode is not None: @@ -954,7 +957,9 @@ def __call__( elif isinstance(self.controlnet, FluxMultiControlNetModel): control_images = [] - for control_image_ in control_image: + # xlab controlnet has a input_hint_block and instantx controlnet does not + controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True + for i, control_image_ in enumerate(control_image): control_image_ = self.prepare_image( image=control_image_, width=width, @@ -966,19 +971,20 @@ def __call__( ) height, width = control_image_.shape[-2:] - # vae encode - control_image_ = self.vae.encode(control_image_).latent_dist.sample() - control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor - - # pack - height_control_image, width_control_image = control_image_.shape[2:] - control_image_ = self._pack_latents( - control_image_, - batch_size * num_images_per_prompt, - num_channels_latents, - height_control_image, - width_control_image, - ) + if self.controlnet.nets[0].input_hint_block is None: + # vae encode + control_image_ = self.vae.encode(control_image_).latent_dist.sample() + control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor + + # pack + height_control_image, width_control_image = control_image_.shape[2:] + control_image_ = self._pack_latents( + control_image_, + batch_size * num_images_per_prompt, + num_channels_latents, + height_control_image, + width_control_image, + ) control_images.append(control_image_) @@ -1129,6 +1135,7 @@ def __call__( img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, + controlnet_blocks_repeat=controlnet_blocks_repeat, )[0] # compute the previous noisy sample x_t -> x_t-1 From fddbab79932eedf1a78041ef38c47df80ab84c90 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 26 Oct 2024 22:13:03 +0900 Subject: [PATCH 08/14] [research_projects] Update README.md to include a note about NF5 T5-xxl (#9775) Update README.md --- examples/research_projects/flux_lora_quantization/README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/research_projects/flux_lora_quantization/README.md b/examples/research_projects/flux_lora_quantization/README.md index ffec85550e51..51005b640221 100644 --- a/examples/research_projects/flux_lora_quantization/README.md +++ b/examples/research_projects/flux_lora_quantization/README.md @@ -5,7 +5,8 @@ This example shows how to fine-tune [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with LoRA and quantization. We show this by using the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset. Steps below summarize the workflow: -* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file. +* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file. + * Even though optional, we load the T5-xxl in NF4 to further reduce the memory foot-print. * `train_dreambooth_lora_flux_miniature.py` takes care of training: * Since we already precomputed the text embeddings, we don't load the text encoders. * We load the VAE and use it to precompute the image latents and we then delete it. @@ -163,4 +164,4 @@ image.save("yarn_merged.png") |-------|-------| | ![Image A](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/merged.png) | ![Image B](https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/quantized_flux_training/unmerged.png) | -As we can notice the first column result follows the style more closely. \ No newline at end of file +As we can notice the first column result follows the style more closely. From 3b5b1c56983004ca1ee4190d0eb65f98b0101d39 Mon Sep 17 00:00:00 2001 From: "Vinh H. Pham" Date: Mon, 28 Oct 2024 17:52:27 +0700 Subject: [PATCH 09/14] [Fix] train_dreambooth_lora_flux_advanced ValueError: unexpected save model: (#9777) fix save state te T5 --- .../train_dreambooth_lora_flux_advanced.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py index ccc390ab7b2c..92d296c0f1e8 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py @@ -1650,6 +1650,8 @@ def save_model_hook(models, weights, output_dir): elif isinstance(model, type(unwrap_model(text_encoder_one))): if args.train_text_encoder: # when --train_text_encoder_ti we don't save the layers text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model) + elif isinstance(model, type(unwrap_model(text_encoder_two))): + pass # when --train_text_encoder_ti and --enable_t5_ti we don't save the layers else: raise ValueError(f"unexpected save model: {model.__class__}") From 493aa74312d4ef86896a5dfc78f671a9d19b24aa Mon Sep 17 00:00:00 2001 From: Biswaroop Date: Mon, 28 Oct 2024 12:07:30 +0100 Subject: [PATCH 10/14] [Fix] remove setting lr for T5 text encoder when using prodigy in flux dreambooth lora script (#9473) * fix: removed setting of text encoder lr for T5 as it's not being tuned * fix: removed setting of text encoder lr for T5 as it's not being tuned --------- Co-authored-by: Sayak Paul Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> --- examples/dreambooth/train_dreambooth_flux.py | 1 - examples/dreambooth/train_dreambooth_lora_flux.py | 1 - 2 files changed, 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index add266d3ac0c..8ab6f4bb6c30 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -1288,7 +1288,6 @@ def load_model_hook(models, input_dir): # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be # --learning_rate params_to_optimize[1]["lr"] = args.learning_rate - params_to_optimize[2]["lr"] = args.learning_rate optimizer = optimizer_class( params_to_optimize, diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index fa4db10f4f7b..5df071b19121 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1370,7 +1370,6 @@ def load_model_hook(models, input_dir): # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be # --learning_rate params_to_optimize[1]["lr"] = args.learning_rate - params_to_optimize[2]["lr"] = args.learning_rate optimizer = optimizer_class( params_to_optimize, From db5b6a963015b885f368da56409d17e88bf4d200 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 28 Oct 2024 16:07:54 +0200 Subject: [PATCH 11/14] [SD 3.5 Dreambooth LoRA] support configurable training block & layers (#9762) * configurable layers * configurable layers * update README * style * add test * style * add layer test, update readme, add nargs * readme * test style * remove print, change nargs * test arg change * style * revert nargs 2/2 * address sayaks comments * style * address sayaks comments --- examples/dreambooth/README_sd3.md | 34 +++++++++ .../dreambooth/test_dreambooth_lora_sd3.py | 71 +++++++++++++++++++ .../dreambooth/train_dreambooth_lora_sd3.py | 39 +++++++++- 3 files changed, 143 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/README_sd3.md b/examples/dreambooth/README_sd3.md index a340be350db8..89d87d65dd44 100644 --- a/examples/dreambooth/README_sd3.md +++ b/examples/dreambooth/README_sd3.md @@ -147,6 +147,40 @@ accelerate launch train_dreambooth_lora_sd3.py \ --push_to_hub ``` +### Targeting Specific Blocks & Layers +As image generation models get bigger & more powerful, more fine-tuners come to find that training only part of the +transformer blocks (sometimes as little as two) can be enough to get great results. +In some cases, it can be even better to maintain some of the blocks/layers frozen. + +For **SD3.5-Large** specifically, you may find this information useful (taken from: [Stable Diffusion 3.5 Large Fine-tuning Tutorial](https://stabilityai.notion.site/Stable-Diffusion-3-5-Large-Fine-tuning-Tutorial-11a61cdcd1968027a15bdbd7c40be8c6#12461cdcd19680788a23c650dab26b93): +> [!NOTE] +> A commonly believed heuristic that we verified once again during the construction of the SD3.5 family of models is that later/higher layers (i.e. `30 - 37`)* impact tertiary details more heavily. Conversely, earlier layers (i.e. `12 - 24` )* influence the overall composition/primary form more. +> So, freezing other layers/targeting specific layers is a viable approach. +> `*`These suggested layers are speculative and not 100% guaranteed. The tips here are more or less a general idea for next steps. +> **Photorealism** +> In preliminary testing, we observed that freezing the last few layers of the architecture significantly improved model training when using a photorealistic dataset, preventing detail degradation introduced by small dataset from happening. +> **Anatomy preservation** +> To dampen any possible degradation of anatomy, training only the attention layers and **not** the adaptive linear layers could help. For reference, below is one of the transformer blocks. + + +We've added `--lora_layers` and `--lora_blocks` to make LoRA training modules configurable. +- with `--lora_blocks` you can specify the block numbers for training. E.g. passing - +```diff +--lora_blocks "12,13,14,15,16,17,18,19,20,21,22,23,24,30,31,32,33,34,35,36,37" +``` +will trigger LoRA training of transformer blocks 12-24 and 30-37. By default, all blocks are trained. +- with `--lora_layers` you can specify the types of layers you wish to train. +By default, the trained layers are - +`attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,attn.to_k,attn.to_out.0,attn.to_q,attn.to_v` +If you wish to have a leaner LoRA / train more blocks over layers you could pass - +```diff ++ --lora_layers attn.to_k,attn.to_q,attn.to_v,attn.to_out.0 +``` +This will reduce LoRA size by roughly 50% for the same rank compared to the default. +However, if you're after compact LoRAs, it's our impression that maintaining the default setting for `--lora_layers` and +freezing some of the early & blocks is usually better. + + ### Text Encoder Training Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported. To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind: diff --git a/examples/dreambooth/test_dreambooth_lora_sd3.py b/examples/dreambooth/test_dreambooth_lora_sd3.py index ec323be4143e..5d6c8bb9938a 100644 --- a/examples/dreambooth/test_dreambooth_lora_sd3.py +++ b/examples/dreambooth/test_dreambooth_lora_sd3.py @@ -38,6 +38,9 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate): pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe" script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py" + transformer_block_idx = 0 + layer_type = "attn.to_k" + def test_dreambooth_lora_sd3(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" @@ -136,6 +139,74 @@ def test_dreambooth_lora_latent_caching(self): starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) self.assertTrue(starts_with_transformer) + def test_dreambooth_lora_block(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --lora_blocks {self.transformer_block_idx} + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. + # In this test, only params of transformer block 0 should be in the state dict + starts_with_transformer = all( + key.startswith("transformer.transformer_blocks.0") for key in lora_state_dict.keys() + ) + self.assertTrue(starts_with_transformer) + + def test_dreambooth_lora_layer(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --lora_layers {self.layer_type} + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # In this test, only transformer params of attention layers `attn.to_k` should be in the state dict + starts_with_transformer = all("attn.to_k" in key for key in lora_state_dict.keys()) + self.assertTrue(starts_with_transformer) + def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 4b39dcfe41b0..fc3c69b8901f 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -571,6 +571,25 @@ def parse_args(input_args=None): "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" ) + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + "The transformer block layers to apply LoRA training on. Please specify the layers in a comma seperated string." + "For examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md" + ), + ) + parser.add_argument( + "--lora_blocks", + type=str, + default=None, + help=( + "The transformer blocks to apply LoRA training on. Please specify the block numbers in a comma seperated manner." + 'E.g. - "--lora_blocks 12,30" will result in lora training of transformer blocks 12 and 30. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_SD3.md' + ), + ) + parser.add_argument( "--adam_epsilon", type=float, @@ -1222,13 +1241,31 @@ def main(args): if args.train_text_encoder: text_encoder_one.gradient_checkpointing_enable() text_encoder_two.gradient_checkpointing_enable() + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + target_modules = [ + "attn.add_k_proj", + "attn.add_q_proj", + "attn.add_v_proj", + "attn.to_add_out", + "attn.to_k", + "attn.to_out.0", + "attn.to_q", + "attn.to_v", + ] + if args.lora_blocks is not None: + target_blocks = [int(block.strip()) for block in args.lora_blocks.split(",")] + target_modules = [ + f"transformer_blocks.{block}.{module}" for block in target_blocks for module in target_modules + ] # now we will add new LoRA weights to the attention layers transformer_lora_config = LoraConfig( r=args.rank, lora_alpha=args.rank, init_lora_weights="gaussian", - target_modules=["to_k", "to_q", "to_v", "to_out.0"], + target_modules=target_modules, ) transformer.add_adapter(transformer_lora_config) From 743a5697f2596567c991e8bc5dd2d4d4a4fffa99 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 28 Oct 2024 17:27:41 +0200 Subject: [PATCH 12/14] [flux dreambooth lora training] make LoRA target modules configurable + small bug fix (#9646) * make lora target modules configurable and change the default * style * make lora target modules configurable and change the default * fix bug when using prodigy and training te * fix mixed precision training as proposed in https://github.com/huggingface/diffusers/pull/9565 for full dreambooth as well * add test and notes * style * address sayaks comments * style * fix test --------- Co-authored-by: Sayak Paul --- examples/dreambooth/README_flux.md | 15 ++++++++ .../dreambooth/test_dreambooth_lora_flux.py | 38 +++++++++++++++++++ examples/dreambooth/train_dreambooth_flux.py | 6 ++- .../dreambooth/train_dreambooth_lora_flux.py | 33 ++++++++++++++-- 4 files changed, 87 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index 69dfd241395b..a724ca53b927 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -170,6 +170,21 @@ accelerate launch train_dreambooth_lora_flux.py \ --push_to_hub ``` +### Target Modules +When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them. +More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore +applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma seperated string +the exact modules for LoRA training. Here are some examples of target modules you can provide: +- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"` +- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"` +- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"` +> [!NOTE] +> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma seperated string: +> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k` +> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k` +> [!NOTE] +> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights. + ### Text Encoder Training Alongside the transformer, fine-tuning of the CLIP text encoder is also supported. diff --git a/examples/dreambooth/test_dreambooth_lora_flux.py b/examples/dreambooth/test_dreambooth_lora_flux.py index d197c8187b87..a76825e29448 100644 --- a/examples/dreambooth/test_dreambooth_lora_flux.py +++ b/examples/dreambooth/test_dreambooth_lora_flux.py @@ -37,6 +37,7 @@ class DreamBoothLoRAFlux(ExamplesTestsAccelerate): instance_prompt = "photo" pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe" script_path = "examples/dreambooth/train_dreambooth_lora_flux.py" + transformer_layer_type = "single_transformer_blocks.0.attn.to_k" def test_dreambooth_lora_flux(self): with tempfile.TemporaryDirectory() as tmpdir: @@ -136,6 +137,43 @@ def test_dreambooth_lora_latent_caching(self): starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys()) self.assertTrue(starts_with_transformer) + def test_dreambooth_lora_layers(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + {self.script_path} + --pretrained_model_name_or_path {self.pretrained_model_name_or_path} + --instance_data_dir {self.instance_data_dir} + --instance_prompt {self.instance_prompt} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --cache_latents + --learning_rate 5.0e-04 + --scale_lr + --lora_layers {self.transformer_layer_type} + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"transformer"` in their names. In this test, we only params of + # transformer.single_transformer_blocks.0.attn.to_k should be in the state dict + starts_with_transformer = all( + key.startswith("transformer.single_transformer_blocks.0.attn.to_k") for key in lora_state_dict.keys() + ) + self.assertTrue(starts_with_transformer) + def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index 8ab6f4bb6c30..f720afef6542 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -161,7 +161,7 @@ def log_validation( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) - pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference @@ -1579,7 +1579,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): ) # handle guidance - if transformer.config.guidance_embeds: + if accelerator.unwrap_model(transformer).config.guidance_embeds: guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: @@ -1693,6 +1693,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): # create pipeline if not args.train_text_encoder: text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) + text_encoder_one.to(weight_dtype) + text_encoder_two.to(weight_dtype) else: # even when training the text encoder we're only training text encoder one text_encoder_two = text_encoder_cls_two.from_pretrained( args.pretrained_model_name_or_path, diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 5df071b19121..b6e657234850 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -554,6 +554,15 @@ def parse_args(input_args=None): "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" ) + parser.add_argument( + "--lora_layers", + type=str, + default=None, + help=( + 'The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only' + ), + ) + parser.add_argument( "--adam_epsilon", type=float, @@ -1186,12 +1195,30 @@ def main(args): if args.train_text_encoder: text_encoder_one.gradient_checkpointing_enable() - # now we will add new LoRA weights to the attention layers + if args.lora_layers is not None: + target_modules = [layer.strip() for layer in args.lora_layers.split(",")] + else: + target_modules = [ + "attn.to_k", + "attn.to_q", + "attn.to_v", + "attn.to_out.0", + "attn.add_k_proj", + "attn.add_q_proj", + "attn.add_v_proj", + "attn.to_add_out", + "ff.net.0.proj", + "ff.net.2", + "ff_context.net.0.proj", + "ff_context.net.2", + ] + + # now we will add new LoRA weights the transformer layers transformer_lora_config = LoraConfig( r=args.rank, lora_alpha=args.rank, init_lora_weights="gaussian", - target_modules=["to_k", "to_q", "to_v", "to_out.0"], + target_modules=target_modules, ) transformer.add_adapter(transformer_lora_config) if args.train_text_encoder: @@ -1367,7 +1394,7 @@ def load_model_hook(models, input_dir): f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " f"When using prodigy only learning_rate is used as the initial learning rate." ) - # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be + # changes the learning rate of text_encoder_parameters_one to be # --learning_rate params_to_optimize[1]["lr"] = args.learning_rate From c5376c569572aab09794341c40ce0658dcf98125 Mon Sep 17 00:00:00 2001 From: Raul Ciotescu Date: Mon, 28 Oct 2024 19:48:04 +0100 Subject: [PATCH 13/14] adds the pipeline for pixart alpha controlnet (#8857) * add the controlnet pipeline for pixart alpha --------- Co-authored-by: YiYi Xu Co-authored-by: Sayak Paul Co-authored-by: junsongc --- examples/community/README.md | 92 ++ examples/research_projects/pixart/.gitignore | 2 + .../pixart/controlnet_pixart_alpha.py | 307 +++++ .../pipeline_pixart_alpha_controlnet.py | 1097 +++++++++++++++ .../research_projects/pixart/requirements.txt | 6 + .../run_pixart_alpha_controlnet_pipeline.py | 75 ++ .../pixart/train_controlnet_hf_diffusers.sh | 23 + .../pixart/train_pixart_controlnet_hf.py | 1176 +++++++++++++++++ 8 files changed, 2778 insertions(+) create mode 100644 examples/research_projects/pixart/.gitignore create mode 100644 examples/research_projects/pixart/controlnet_pixart_alpha.py create mode 100644 examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py create mode 100644 examples/research_projects/pixart/requirements.txt create mode 100644 examples/research_projects/pixart/run_pixart_alpha_controlnet_pipeline.py create mode 100755 examples/research_projects/pixart/train_controlnet_hf_diffusers.sh create mode 100644 examples/research_projects/pixart/train_pixart_controlnet_hf.py diff --git a/examples/community/README.md b/examples/community/README.md index 4f16f65df8fa..743993eb44c3 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -73,6 +73,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif | Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) | | FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) | | AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) | +PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) | | HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) | | [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/pcuenq/mdm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) | @@ -4445,3 +4446,94 @@ grid_image.save(grid_dir + "sample.png") `pag_scale` : guidance scale of PAG (ex: 5.0) `pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0']) + +# PIXART-α Controlnet pipeline + +[Project](https://pixart-alpha.github.io/) / [GitHub](https://github.com/PixArt-alpha/PixArt-alpha/blob/master/asset/docs/pixart_controlnet.md) + +This the implementation of the controlnet model and the pipelne for the Pixart-alpha model, adapted to use the HuggingFace Diffusers. + +## Example Usage + +This example uses the Pixart HED Controlnet model, converted from the control net model as trained by the authors of the paper. + +```py +import sys +import os +import torch +import torchvision.transforms as T +import torchvision.transforms.functional as TF + +from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline +from diffusers.utils import load_image + +from diffusers.image_processor import PixArtImageProcessor + +from controlnet_aux import HEDdetector + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel + +controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet" + +weight_dtype = torch.float16 +image_size = 1024 + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +torch.manual_seed(0) + +# load controlnet +controlnet = PixArtControlNetAdapterModel.from_pretrained( + controlnet_repo_id, + torch_dtype=weight_dtype, + use_safetensors=True, +).to(device) + +pipe = PixArtAlphaControlnetPipeline.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", + controlnet=controlnet, + torch_dtype=weight_dtype, + use_safetensors=True, +).to(device) + +images_path = "images" +control_image_file = "0_7.jpg" + +prompt = "battleship in space, galaxy in background" + +control_image_name = control_image_file.split('.')[0] + +control_image = load_image(f"{images_path}/{control_image_file}") +print(control_image.size) +height, width = control_image.size + +hed = HEDdetector.from_pretrained("lllyasviel/Annotators") + +condition_transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB')), + T.CenterCrop([image_size, image_size]), +]) + +control_image = condition_transform(control_image) +hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size) + +hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg") + +# run pipeline +with torch.no_grad(): + out = pipe( + prompt=prompt, + image=hed_edge, + num_inference_steps=14, + guidance_scale=4.5, + height=image_size, + width=image_size, + ) + + out.images[0].save(f"{images_path}//{control_image_name}_output.jpg") + +``` + +In the folder examples/pixart there is also a script that can be used to train new models. +Please check the script `train_controlnet_hf_diffusers.sh` on how to start the training. \ No newline at end of file diff --git a/examples/research_projects/pixart/.gitignore b/examples/research_projects/pixart/.gitignore new file mode 100644 index 000000000000..4be0fcb237f5 --- /dev/null +++ b/examples/research_projects/pixart/.gitignore @@ -0,0 +1,2 @@ +images/ +output/ \ No newline at end of file diff --git a/examples/research_projects/pixart/controlnet_pixart_alpha.py b/examples/research_projects/pixart/controlnet_pixart_alpha.py new file mode 100644 index 000000000000..b7f5a427e52e --- /dev/null +++ b/examples/research_projects/pixart/controlnet_pixart_alpha.py @@ -0,0 +1,307 @@ +from typing import Any, Dict, Optional + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models import PixArtTransformer2DModel +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.torch_utils import is_torch_version + + +class PixArtControlNetAdapterBlock(nn.Module): + def __init__( + self, + block_index, + # taken from PixArtTransformer2DModel + num_attention_heads: int = 16, + attention_head_dim: int = 72, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = 1152, + attention_bias: bool = True, + activation_fn: str = "gelu-approximate", + num_embeds_ada_norm: Optional[int] = 1000, + upcast_attention: bool = False, + norm_type: str = "ada_norm_single", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + attention_type: Optional[str] = "default", + ): + super().__init__() + + self.block_index = block_index + self.inner_dim = num_attention_heads * attention_head_dim + + # the first block has a zero before layer + if self.block_index == 0: + self.before_proj = nn.Linear(self.inner_dim, self.inner_dim) + nn.init.zeros_(self.before_proj.weight) + nn.init.zeros_(self.before_proj.bias) + + self.transformer_block = BasicTransformerBlock( + self.inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + + self.after_proj = nn.Linear(self.inner_dim, self.inner_dim) + nn.init.zeros_(self.after_proj.weight) + nn.init.zeros_(self.after_proj.bias) + + def train(self, mode: bool = True): + self.transformer_block.train(mode) + + if self.block_index == 0: + self.before_proj.train(mode) + + self.after_proj.train(mode) + + def forward( + self, + hidden_states: torch.Tensor, + controlnet_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ): + if self.block_index == 0: + controlnet_states = self.before_proj(controlnet_states) + controlnet_states = hidden_states + controlnet_states + + controlnet_states_down = self.transformer_block( + hidden_states=controlnet_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + added_cond_kwargs=added_cond_kwargs, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + class_labels=None, + ) + + controlnet_states_left = self.after_proj(controlnet_states_down) + + return controlnet_states_left, controlnet_states_down + + +class PixArtControlNetAdapterModel(ModelMixin, ConfigMixin): + # N=13, as specified in the paper https://arxiv.org/html/2401.05252v1/#S4 ControlNet-Transformer + @register_to_config + def __init__(self, num_layers=13) -> None: + super().__init__() + + self.num_layers = num_layers + + self.controlnet_blocks = nn.ModuleList( + [PixArtControlNetAdapterBlock(block_index=i) for i in range(num_layers)] + ) + + @classmethod + def from_transformer(cls, transformer: PixArtTransformer2DModel): + control_net = PixArtControlNetAdapterModel() + + # copied the specified number of blocks from the transformer + for depth in range(control_net.num_layers): + control_net.controlnet_blocks[depth].transformer_block.load_state_dict( + transformer.transformer_blocks[depth].state_dict() + ) + + return control_net + + def train(self, mode: bool = True): + for block in self.controlnet_blocks: + block.train(mode) + + +class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin): + def __init__( + self, + transformer: PixArtTransformer2DModel, + controlnet: PixArtControlNetAdapterModel, + blocks_num=13, + init_from_transformer=False, + training=False, + ): + super().__init__() + + self.blocks_num = blocks_num + self.gradient_checkpointing = False + self.register_to_config(**transformer.config) + self.training = training + + if init_from_transformer: + # copies the specified number of blocks from the transformer + controlnet.from_transformer(transformer, self.blocks_num) + + self.transformer = transformer + self.controlnet = controlnet + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + controlnet_cond: Optional[torch.Tensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + if self.transformer.use_additional_conditions and added_cond_kwargs is None: + raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.") + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + batch_size = hidden_states.shape[0] + height, width = ( + hidden_states.shape[-2] // self.transformer.config.patch_size, + hidden_states.shape[-1] // self.transformer.config.patch_size, + ) + hidden_states = self.transformer.pos_embed(hidden_states) + + timestep, embedded_timestep = self.transformer.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + if self.transformer.caption_projection is not None: + encoder_hidden_states = self.transformer.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + controlnet_states_down = None + if controlnet_cond is not None: + controlnet_states_down = self.transformer.pos_embed(controlnet_cond) + + # 2. Blocks + for block_index, block in enumerate(self.transformer.transformer_blocks): + if self.training and self.gradient_checkpointing: + # rc todo: for training and gradient checkpointing + print("Gradient checkpointing is not supported for the controlnet transformer model, yet.") + exit(1) + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + None, + **ckpt_kwargs, + ) + else: + # the control nets are only used for the blocks 1 to self.blocks_num + if block_index > 0 and block_index <= self.blocks_num and controlnet_states_down is not None: + controlnet_states_left, controlnet_states_down = self.controlnet.controlnet_blocks[ + block_index - 1 + ]( + hidden_states=hidden_states, # used only in the first block + controlnet_states=controlnet_states_down, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + added_cond_kwargs=added_cond_kwargs, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + + hidden_states = hidden_states + controlnet_states_left + + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=None, + ) + + # 3. Output + shift, scale = ( + self.transformer.scale_shift_table[None] + + embedded_timestep[:, None].to(self.transformer.scale_shift_table.device) + ).chunk(2, dim=1) + hidden_states = self.transformer.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device) + hidden_states = self.transformer.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + hidden_states = hidden_states.reshape( + shape=( + -1, + height, + width, + self.transformer.config.patch_size, + self.transformer.config.patch_size, + self.transformer.out_channels, + ) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=( + -1, + self.transformer.out_channels, + height * self.transformer.config.patch_size, + width * self.transformer.config.patch_size, + ) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py new file mode 100644 index 000000000000..aace66f9c18e --- /dev/null +++ b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py @@ -0,0 +1,1097 @@ +# Copyright 2024 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers.image_processor import PipelineImageInput, PixArtImageProcessor +from diffusers.models import AutoencoderKL, PixArtTransformer2DModel +from diffusers.pipelines import DiffusionPipeline, ImagePipelineOutput +from diffusers.schedulers import DPMSolverMultistepScheduler +from diffusers.utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import PixArtAlphaPipeline + + >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too. + >>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "A small cactus with a happy face in the Sahara desert." + >>> image = pipe(prompt).images[0] + ``` +""" + +ASPECT_RATIO_1024_BIN = { + "0.25": [512.0, 2048.0], + "0.28": [512.0, 1856.0], + "0.32": [576.0, 1792.0], + "0.33": [576.0, 1728.0], + "0.35": [576.0, 1664.0], + "0.4": [640.0, 1600.0], + "0.42": [640.0, 1536.0], + "0.48": [704.0, 1472.0], + "0.5": [704.0, 1408.0], + "0.52": [704.0, 1344.0], + "0.57": [768.0, 1344.0], + "0.6": [768.0, 1280.0], + "0.68": [832.0, 1216.0], + "0.72": [832.0, 1152.0], + "0.78": [896.0, 1152.0], + "0.82": [896.0, 1088.0], + "0.88": [960.0, 1088.0], + "0.94": [960.0, 1024.0], + "1.0": [1024.0, 1024.0], + "1.07": [1024.0, 960.0], + "1.13": [1088.0, 960.0], + "1.21": [1088.0, 896.0], + "1.29": [1152.0, 896.0], + "1.38": [1152.0, 832.0], + "1.46": [1216.0, 832.0], + "1.67": [1280.0, 768.0], + "1.75": [1344.0, 768.0], + "2.0": [1408.0, 704.0], + "2.09": [1472.0, 704.0], + "2.4": [1536.0, 640.0], + "2.5": [1600.0, 640.0], + "3.0": [1728.0, 576.0], + "4.0": [2048.0, 512.0], +} + +ASPECT_RATIO_512_BIN = { + "0.25": [256.0, 1024.0], + "0.28": [256.0, 928.0], + "0.32": [288.0, 896.0], + "0.33": [288.0, 864.0], + "0.35": [288.0, 832.0], + "0.4": [320.0, 800.0], + "0.42": [320.0, 768.0], + "0.48": [352.0, 736.0], + "0.5": [352.0, 704.0], + "0.52": [352.0, 672.0], + "0.57": [384.0, 672.0], + "0.6": [384.0, 640.0], + "0.68": [416.0, 608.0], + "0.72": [416.0, 576.0], + "0.78": [448.0, 576.0], + "0.82": [448.0, 544.0], + "0.88": [480.0, 544.0], + "0.94": [480.0, 512.0], + "1.0": [512.0, 512.0], + "1.07": [512.0, 480.0], + "1.13": [544.0, 480.0], + "1.21": [544.0, 448.0], + "1.29": [576.0, 448.0], + "1.38": [576.0, 416.0], + "1.46": [608.0, 416.0], + "1.67": [640.0, 384.0], + "1.75": [672.0, 384.0], + "2.0": [704.0, 352.0], + "2.09": [736.0, 352.0], + "2.4": [768.0, 320.0], + "2.5": [800.0, 320.0], + "3.0": [864.0, 288.0], + "4.0": [1024.0, 256.0], +} + +ASPECT_RATIO_256_BIN = { + "0.25": [128.0, 512.0], + "0.28": [128.0, 464.0], + "0.32": [144.0, 448.0], + "0.33": [144.0, 432.0], + "0.35": [144.0, 416.0], + "0.4": [160.0, 400.0], + "0.42": [160.0, 384.0], + "0.48": [176.0, 368.0], + "0.5": [176.0, 352.0], + "0.52": [176.0, 336.0], + "0.57": [192.0, 336.0], + "0.6": [192.0, 320.0], + "0.68": [208.0, 304.0], + "0.72": [208.0, 288.0], + "0.78": [224.0, 288.0], + "0.82": [224.0, 272.0], + "0.88": [240.0, 272.0], + "0.94": [240.0, 256.0], + "1.0": [256.0, 256.0], + "1.07": [256.0, 240.0], + "1.13": [272.0, 240.0], + "1.21": [272.0, 224.0], + "1.29": [288.0, 224.0], + "1.38": [288.0, 208.0], + "1.46": [304.0, 208.0], + "1.67": [320.0, 192.0], + "1.75": [336.0, 192.0], + "2.0": [352.0, 176.0], + "2.09": [368.0, 176.0], + "2.4": [384.0, 160.0], + "2.5": [400.0, 160.0], + "3.0": [432.0, 144.0], + "4.0": [512.0, 128.0], +} + + +def get_closest_hw(width, height, image_size): + if image_size == 1024: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif image_size == 512: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + else: + raise ValueError("Invalid image size") + + height, width = PixArtImageProcessor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + return width, height + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class PixArtAlphaControlnetPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using PixArt-Alpha. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. PixArt-Alpha uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`PixArtTransformer2DModel`]): + A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: PixArtTransformer2DModel, + controlnet: PixArtControlNetAdapterModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + # change to the controlnet transformer model + transformer = PixArtControlNetTransformerModel(transformer=transformer, controlnet=controlnet) + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + controlnet=controlnet, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 120, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.controlnet.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_steps, + image=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + if image is not None: + self.check_image(image, prompt, prompt_embeds) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance: + image = torch.cat([image] * 2) + + return image + + # based on pipeline_pixart_inpaiting.py + def prepare_image_latents(self, image, device, dtype): + image = image.to(device=device, dtype=dtype) + + image_latents = self.vae.encode(image).latent_dist.sample() + image_latents = image_latents * self.vae.config.scaling_factor + return image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.5, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = None, + width: Optional[int] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + # rc todo: controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + # rc todo: control_guidance_start = 0.0, + # rc todo: control_guidance_end = 1.0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: int = 1, + clean_caption: bool = True, + use_resolution_binning: bool = True, + max_sequence_length: int = 120, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + # 1. Check inputs. Raise error if not correct + height = height or self.transformer.config.sample_size * self.vae_scale_factor + width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_256_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) + + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_steps, + image, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 4.1 Prepare image + image_latents = None + if image is not None: + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.transformer.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + image_latents = self.prepare_image_latents(image, device, self.transformer.controlnet.dtype) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if self.transformer.config.sample_size == 128: + resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) + aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) + resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) + aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) + + if do_classifier_free_guidance: + resolution = torch.cat([resolution, resolution], dim=0) + aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) + + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + controlnet_cond=image_latents, + # rc todo: controlnet_conditioning_scale=1.0, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + else: + noise_pred = noise_pred + + # compute previous image: x_t -> x_t-1 + if num_inference_steps == 1: + # For DMD one step sampling: https://arxiv.org/abs/2311.18828 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample + else: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height) + else: + image = latents + + if not output_type == "latent": + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/examples/research_projects/pixart/requirements.txt b/examples/research_projects/pixart/requirements.txt new file mode 100644 index 000000000000..2b307927ee9f --- /dev/null +++ b/examples/research_projects/pixart/requirements.txt @@ -0,0 +1,6 @@ +transformers +SentencePiece +torchvision +controlnet-aux +datasets +# wandb \ No newline at end of file diff --git a/examples/research_projects/pixart/run_pixart_alpha_controlnet_pipeline.py b/examples/research_projects/pixart/run_pixart_alpha_controlnet_pipeline.py new file mode 100644 index 000000000000..0014c590541b --- /dev/null +++ b/examples/research_projects/pixart/run_pixart_alpha_controlnet_pipeline.py @@ -0,0 +1,75 @@ +import torch +import torchvision.transforms as T +from controlnet_aux import HEDdetector + +from diffusers.utils import load_image +from examples.research_projects.pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel +from examples.research_projects.pixart.pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline + + +controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet" + +weight_dtype = torch.float16 +image_size = 1024 + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +torch.manual_seed(0) + +# load controlnet +controlnet = PixArtControlNetAdapterModel.from_pretrained( + controlnet_repo_id, + torch_dtype=weight_dtype, + use_safetensors=True, +).to(device) + +pipe = PixArtAlphaControlnetPipeline.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", + controlnet=controlnet, + torch_dtype=weight_dtype, + use_safetensors=True, +).to(device) + +images_path = "images" +control_image_file = "0_7.jpg" + +# prompt = "cinematic photo of superman in action . 35mm photograph, film, bokeh, professional, 4k, highly detailed" +# prompt = "yellow modern car, city in background, beautiful rainy day" +# prompt = "modern villa, clear sky, suny day . 35mm photograph, film, bokeh, professional, 4k, highly detailed" +# prompt = "robot dog toy in park . 35mm photograph, film, bokeh, professional, 4k, highly detailed" +# prompt = "purple car, on highway, beautiful sunny day" +# prompt = "realistical photo of a loving couple standing in the open kitchen of the living room, cooking ." +prompt = "battleship in space, galaxy in background" + +control_image_name = control_image_file.split(".")[0] + +control_image = load_image(f"{images_path}/{control_image_file}") +print(control_image.size) +height, width = control_image.size + +hed = HEDdetector.from_pretrained("lllyasviel/Annotators") + +condition_transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB")), + T.CenterCrop([image_size, image_size]), + ] +) + +control_image = condition_transform(control_image) +hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size) + +hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg") + +# run pipeline +with torch.no_grad(): + out = pipe( + prompt=prompt, + image=hed_edge, + num_inference_steps=14, + guidance_scale=4.5, + height=image_size, + width=image_size, + ) + + out.images[0].save(f"{images_path}//{control_image_name}_output.jpg") diff --git a/examples/research_projects/pixart/train_controlnet_hf_diffusers.sh b/examples/research_projects/pixart/train_controlnet_hf_diffusers.sh new file mode 100755 index 000000000000..0abd88f19e18 --- /dev/null +++ b/examples/research_projects/pixart/train_controlnet_hf_diffusers.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +# run +# accelerate config + +# check with +# accelerate env + +export MODEL_DIR="PixArt-alpha/PixArt-XL-2-512x512" +export OUTPUT_DIR="output/pixart-controlnet-hf-diffusers-test" + +accelerate launch ./train_pixart_controlnet_hf.py --mixed_precision="fp16" \ + --pretrained_model_name_or_path=$MODEL_DIR \ + --output_dir=$OUTPUT_DIR \ + --dataset_name=fusing/fill50k \ + --resolution=512 \ + --learning_rate=1e-5 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --report_to="wandb" \ + --seed=42 \ + --dataloader_num_workers=8 +# --lr_scheduler="cosine" --lr_warmup_steps=0 \ diff --git a/examples/research_projects/pixart/train_pixart_controlnet_hf.py b/examples/research_projects/pixart/train_pixart_controlnet_hf.py new file mode 100644 index 000000000000..995a20dfa28e --- /dev/null +++ b/examples/research_projects/pixart/train_pixart_controlnet_hf.py @@ -0,0 +1,1176 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fine-tuning script for Stable Diffusion for text2image with HuggingFace diffusers.""" + +import argparse +import gc +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import accelerate +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import T5EncoderModel, T5Tokenizer + +import diffusers +from diffusers import AutoencoderKL, DDPMScheduler +from diffusers.models import PixArtTransformer2DModel +from diffusers.optimization import get_scheduler +from diffusers.training_utils import compute_snr +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module +from examples.research_projects.pixart.controlnet_pixart_alpha import ( + PixArtControlNetAdapterModel, + PixArtControlNetTransformerModel, +) + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.29.2") + +logger = get_logger(__name__, log_level="INFO") + + +def log_validation( + vae, + transformer, + controlnet, + tokenizer, + scheduler, + text_encoder, + args, + accelerator, + weight_dtype, + step, + is_final_validation=False, +): + if weight_dtype == torch.float16 or weight_dtype == torch.bfloat16: + raise ValueError( + "Validation is not supported with mixed precision training, disable validation and use the validation script, that will generate images from the saved checkpoints." + ) + + if not is_final_validation: + logger.info(f"Running validation step {step} ... ") + + controlnet = accelerator.unwrap_model(controlnet) + pipeline = PixArtAlphaControlnetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + transformer=transformer, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + controlnet=controlnet, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + else: + logger.info("Running validation - final ... ") + + controlnet = PixArtControlNetAdapterModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) + + pipeline = PixArtAlphaControlnetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + controlnet=controlnet, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) + + image_logs = [] + + for validation_prompt, validation_image in zip(validation_prompts, validation_images): + validation_image = Image.open(validation_image).convert("RGB") + validation_image = validation_image.resize((args.resolution, args.resolution)) + + images = [] + + for _ in range(args.num_validation_images): + image = pipeline( + prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator + ).images[0] + images.append(image) + + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images = [] + + formatted_images.append(np.asarray(validation_image)) + + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({tracker_key: formatted_images}) + else: + logger.warning(f"image logging not implemented for {tracker.name}") + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + logger.info("Validation done!!") + + return image_logs + + +def save_model_card(repo_id: str, image_logs=None, base_model=str, dataset_name=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + model_description = f""" +# controlnet-{repo_id} + +These are controlnet weights trained on {base_model} with new type of conditioning. +{img_str} +""" + + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="openrail++", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = [ + "pixart-alpha", + "pixart-alpha-diffusers", + "text-to-image", + "diffusers", + "controlnet", + "diffusers-training", + ] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained controlnet model or model identifier from huggingface.co/models." + " If not specified controlnet weights are initialized from the transformer.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + nargs="+", + default=None, + help="One or more prompts to be evaluated every `--validation_steps`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.", + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="pixart-controlnet", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + # ----Diffusion Training Arguments---- + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + + parser.add_argument( + "--tracker_project_name", + type=str, + default="pixart_controlnet", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + + # Sanity checks + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Need either a dataset name or a training folder.") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + return args + + +def main(): + args = parse_args() + + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # See Section 3.1. of the paper. + max_length = 120 + + # For mixed precision training we cast all non-trainable weigths (vae, text_encoder) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler", torch_dtype=weight_dtype + ) + tokenizer = T5Tokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, torch_dtype=weight_dtype + ) + + text_encoder = T5EncoderModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, torch_dtype=weight_dtype + ) + text_encoder.requires_grad_(False) + text_encoder.to(accelerator.device) + + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + vae.requires_grad_(False) + vae.to(accelerator.device) + + transformer = PixArtTransformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="transformer") + transformer.to(accelerator.device) + transformer.requires_grad_(False) + + if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + controlnet = PixArtControlNetAdapterModel.from_pretrained(args.controlnet_model_name_or_path) + else: + logger.info("Initializing controlnet weights from transformer.") + controlnet = PixArtControlNetAdapterModel.from_transformer(transformer) + + transformer.to(dtype=weight_dtype) + + controlnet.to(accelerator.device) + controlnet.train() + + def unwrap_model(model, keep_fp32_wrapper=True): + model = accelerator.unwrap_model(model, keep_fp32_wrapper=keep_fp32_wrapper) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # 10. Handle saving and loading of checkpoints + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + for _, model in enumerate(models): + if isinstance(model, PixArtControlNetTransformerModel): + print(f"Saving model {model.__class__.__name__} to {output_dir}") + model.controlnet.save_pretrained(os.path.join(output_dir, "controlnet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + # rc todo: test and load the controlenet adapter and transformer + raise ValueError("load model hook not tested") + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + if isinstance(model, PixArtControlNetTransformerModel): + load_model = PixArtControlNetAdapterModel.from_pretrained(input_dir, subfolder="controlnet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + transformer.enable_xformers_memory_efficient_attention() + controlnet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if unwrap_model(controlnet).dtype != torch.float32: + raise ValueError( + f"Transformer loaded as datatype {unwrap_model(controlnet).dtype}. The trainable parameters should be in torch.float32." + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.gradient_checkpointing: + transformer.enable_gradient_checkpointing() + controlnet.enable_gradient_checkpointing() + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + params_to_optimize = controlnet.parameters() + optimizer = optimizer_cls( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + data_dir=args.train_data_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + + if args.conditioning_image_column is None: + conditioning_image_column = column_names[2] + logger.info(f"conditioning image column defaulting to {conditioning_image_column}") + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + raise ValueError( + f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + # Preprocessing the datasets. + # We need to tokenize input captions and transform the images. + def tokenize_captions(examples, is_train=True, proportion_empty_prompts=0.0, max_length=120): + captions = [] + for caption in examples[caption_column]: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer(captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt") + return inputs.input_ids, inputs.attention_mask + + # Preprocessing the datasets. + train_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + conditioning_image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["pixel_values"] = [train_transforms(image) for image in images] + + conditioning_images = [image.convert("RGB") for image in examples[args.conditioning_image_column]] + examples["conditioning_pixel_values"] = [conditioning_image_transforms(image) for image in conditioning_images] + + examples["input_ids"], examples["prompt_attention_mask"] = tokenize_captions( + examples, proportion_empty_prompts=args.proportion_empty_prompts, max_length=max_length + ) + + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.stack([example["input_ids"] for example in examples]) + prompt_attention_mask = torch.stack([example["prompt_attention_mask"] for example in examples]) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "input_ids": input_ids, + "prompt_attention_mask": prompt_attention_mask, + } + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + ) + + # Prepare everything with our `accelerator`. + controlnet_transformer = PixArtControlNetTransformerModel(transformer, controlnet, training=True) + controlnet_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet_transformer, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers(args.tracker_project_name, config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + latent_channels = transformer.config.in_channels + for epoch in range(first_epoch, args.num_train_epochs): + controlnet_transformer.controlnet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(controlnet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * vae.config.scaling_factor + + # Convert control images to latent space + controlnet_image_latents = vae.encode( + batch["conditioning_pixel_values"].to(dtype=weight_dtype) + ).latent_dist.sample() + controlnet_image_latents = controlnet_image_latents * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), device=latents.device + ) + + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + prompt_embeds = text_encoder(batch["input_ids"], attention_mask=batch["prompt_attention_mask"])[0] + prompt_attention_mask = batch["prompt_attention_mask"] + # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config(prediction_type=args.prediction_type) + + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # Prepare micro-conditions. + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if getattr(transformer, "module", transformer).config.sample_size == 128: + resolution = torch.tensor([args.resolution, args.resolution]).repeat(bsz, 1) + aspect_ratio = torch.tensor([float(args.resolution / args.resolution)]).repeat(bsz, 1) + resolution = resolution.to(dtype=weight_dtype, device=latents.device) + aspect_ratio = aspect_ratio.to(dtype=weight_dtype, device=latents.device) + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + # Predict the noise residual and compute loss + model_pred = controlnet_transformer( + noisy_latents, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timesteps, + controlnet_cond=controlnet_image_latents, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if transformer.config.out_channels // 2 == latent_channels: + model_pred = model_pred.chunk(2, dim=1)[0] + else: + model_pred = model_pred + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(noise_scheduler, timesteps) + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + train_loss += avg_loss.item() / args.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = controlnet_transformer.controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + + logger.info(f"Saved state to {save_path}") + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + log_validation( + vae, + transformer, + controlnet_transformer.controlnet, + tokenizer, + noise_scheduler, + text_encoder, + args, + accelerator, + weight_dtype, + global_step, + is_final_validation=False, + ) + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + controlnet = unwrap_model(controlnet_transformer.controlnet, keep_fp32_wrapper=False) + controlnet.save_pretrained(os.path.join(args.output_dir, "controlnet")) + + image_logs = None + if args.validation_prompt is not None: + image_logs = log_validation( + vae, + transformer, + controlnet, + tokenizer, + noise_scheduler, + text_encoder, + args, + accelerator, + weight_dtype, + global_step, + is_final_validation=True, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + dataset_name=args.dataset_name, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + main() From 0d1d267b12e47b40b0e8f265339c76e0f45f8c49 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 29 Oct 2024 13:14:36 +0530 Subject: [PATCH 14/14] [core] Allegro T2V (#9736) * update * refactor transformer part 1 * refactor part 2 * refactor part 3 * make style * refactor part 4; modeling tests * make style * refactor part 5 * refactor part 6 * gradient checkpointing * pipeline tests (broken atm) * update * add coauthor Co-Authored-By: Huan Yang * refactor part 7 * add docs * make style * add coauthor Co-Authored-By: YiYi Xu * make fix-copies * undo unrelated change * revert changes to embeddings, normalization, transformer * refactor part 8 * make style * refactor part 9 * make style * fix * apply suggestions from review * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * update example * remove attention mask for self-attention * update * copied from * update * update --------- Co-authored-by: Huan Yang Co-authored-by: YiYi Xu Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/_toctree.yml | 6 + .../en/api/models/allegro_transformer3d.md | 30 + .../en/api/models/autoencoderkl_allegro.md | 37 + docs/source/en/api/pipelines/allegro.md | 34 + src/diffusers/__init__.py | 6 + src/diffusers/models/__init__.py | 4 + src/diffusers/models/attention_processor.py | 94 ++ src/diffusers/models/autoencoders/__init__.py | 1 + .../autoencoders/autoencoder_kl_allegro.py | 1155 +++++++++++++++++ src/diffusers/models/embeddings.py | 56 +- src/diffusers/models/normalization.py | 6 +- src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_allegro.py | 422 ++++++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/allegro/__init__.py | 48 + .../pipelines/allegro/pipeline_allegro.py | 918 +++++++++++++ .../pipelines/allegro/pipeline_output.py | 23 + src/diffusers/utils/dummy_pt_objects.py | 30 + .../dummy_torch_and_transformers_objects.py | 15 + .../test_models_transformer_allegro.py | 79 ++ tests/pipelines/allegro/__init__.py | 0 tests/pipelines/allegro/test_allegro.py | 337 +++++ tests/pipelines/test_pipelines_common.py | 1 + 23 files changed, 3300 insertions(+), 5 deletions(-) create mode 100644 docs/source/en/api/models/allegro_transformer3d.md create mode 100644 docs/source/en/api/models/autoencoderkl_allegro.md create mode 100644 docs/source/en/api/pipelines/allegro.md create mode 100644 src/diffusers/models/autoencoders/autoencoder_kl_allegro.py create mode 100644 src/diffusers/models/transformers/transformer_allegro.py create mode 100644 src/diffusers/pipelines/allegro/__init__.py create mode 100644 src/diffusers/pipelines/allegro/pipeline_allegro.py create mode 100644 src/diffusers/pipelines/allegro/pipeline_output.py create mode 100644 tests/models/transformers/test_models_transformer_allegro.py create mode 100644 tests/pipelines/allegro/__init__.py create mode 100644 tests/pipelines/allegro/test_allegro.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 87ff9b1fb81a..c0d571a5864d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -252,6 +252,8 @@ title: SparseControlNetModel title: ControlNets - sections: + - local: api/models/allegro_transformer3d + title: AllegroTransformer3DModel - local: api/models/aura_flow_transformer2d title: AuraFlowTransformer2DModel - local: api/models/cogvideox_transformer3d @@ -300,6 +302,8 @@ - sections: - local: api/models/autoencoderkl title: AutoencoderKL + - local: api/models/autoencoderkl_allegro + title: AutoencoderKLAllegro - local: api/models/autoencoderkl_cogvideox title: AutoencoderKLCogVideoX - local: api/models/asymmetricautoencoderkl @@ -318,6 +322,8 @@ sections: - local: api/pipelines/overview title: Overview + - local: api/pipelines/allegro + title: Allegro - local: api/pipelines/amused title: aMUSEd - local: api/pipelines/animatediff diff --git a/docs/source/en/api/models/allegro_transformer3d.md b/docs/source/en/api/models/allegro_transformer3d.md new file mode 100644 index 000000000000..e70026fe4bfc --- /dev/null +++ b/docs/source/en/api/models/allegro_transformer3d.md @@ -0,0 +1,30 @@ + + +# AllegroTransformer3DModel + +A Diffusion Transformer model for 3D data from [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI. + +The model can be loaded with the following code snippet. + +```python +from diffusers import AllegroTransformer3DModel + +vae = AllegroTransformer3DModel.from_pretrained("rhymes-ai/Allegro", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda") +``` + +## AllegroTransformer3DModel + +[[autodoc]] AllegroTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/models/autoencoderkl_allegro.md b/docs/source/en/api/models/autoencoderkl_allegro.md new file mode 100644 index 000000000000..fd9d10d5724b --- /dev/null +++ b/docs/source/en/api/models/autoencoderkl_allegro.md @@ -0,0 +1,37 @@ + + +# AutoencoderKLAllegro + +The 3D variational autoencoder (VAE) model with KL loss used in [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI. + +The model can be loaded with the following code snippet. + +```python +from diffusers import AutoencoderKLAllegro + +vae = AutoencoderKLCogVideoX.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32).to("cuda") +``` + +## AutoencoderKLAllegro + +[[autodoc]] AutoencoderKLAllegro + - decode + - encode + - all + +## AutoencoderKLOutput + +[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput diff --git a/docs/source/en/api/pipelines/allegro.md b/docs/source/en/api/pipelines/allegro.md new file mode 100644 index 000000000000..e13e339944e5 --- /dev/null +++ b/docs/source/en/api/pipelines/allegro.md @@ -0,0 +1,34 @@ + + +# Allegro + +[Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) from RhymesAI, by Yuan Zhou, Qiuyue Wang, Yuxuan Cai, Huan Yang. + +The abstract from the paper is: + +*Significant advancements have been made in the field of video generation, with the open-source community contributing a wealth of research papers and tools for training high-quality models. However, despite these efforts, the available information and resources remain insufficient for achieving commercial-level performance. In this report, we open the black box and introduce Allegro, an advanced video generation model that excels in both quality and temporal consistency. We also highlight the current limitations in the field and present a comprehensive methodology for training high-performance, commercial-level video generation models, addressing key aspects such as data, model architecture, training pipeline, and evaluation. Our user study shows that Allegro surpasses existing open-source models and most commercial models, ranking just behind Hailuo and Kling. Code: https://github.com/rhymes-ai/Allegro , Model: https://huggingface.co/rhymes-ai/Allegro , Gallery: https://rhymes.ai/allegro_gallery .* + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +## AllegroPipeline + +[[autodoc]] AllegroPipeline + - all + - __call__ + +## AllegroPipelineOutput + +[[autodoc]] pipelines.allegro.pipeline_output.AllegroPipelineOutput diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 789458a26299..ff59a3839552 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -77,9 +77,11 @@ else: _import_structure["models"].extend( [ + "AllegroTransformer3DModel", "AsymmetricAutoencoderKL", "AuraFlowTransformer2DModel", "AutoencoderKL", + "AutoencoderKLAllegro", "AutoencoderKLCogVideoX", "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", @@ -237,6 +239,7 @@ else: _import_structure["pipelines"].extend( [ + "AllegroPipeline", "AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline", "AmusedImg2ImgPipeline", @@ -556,9 +559,11 @@ from .utils.dummy_pt_objects import * # noqa F403 else: from .models import ( + AllegroTransformer3DModel, AsymmetricAutoencoderKL, AuraFlowTransformer2DModel, AutoencoderKL, + AutoencoderKLAllegro, AutoencoderKLCogVideoX, AutoencoderKLTemporalDecoder, AutoencoderOobleck, @@ -697,6 +702,7 @@ from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipelines import ( + AllegroPipeline, AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, AmusedImg2ImgPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 4dda8c36ba1c..38dd2819133d 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -28,6 +28,7 @@ _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] + _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] @@ -54,6 +55,7 @@ _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] + _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] @@ -81,6 +83,7 @@ from .autoencoders import ( AsymmetricAutoencoderKL, AutoencoderKL, + AutoencoderKLAllegro, AutoencoderKLCogVideoX, AutoencoderKLTemporalDecoder, AutoencoderOobleck, @@ -97,6 +100,7 @@ from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( + AllegroTransformer3DModel, AuraFlowTransformer2DModel, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e735c4ee7d17..db88ecbbb9d3 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1521,6 +1521,100 @@ def __call__( return hidden_states, encoder_hidden_states +class AllegroAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # Apply RoPE if needed + if image_rotary_emb is not None and not attn.is_cross_attention: + from .embeddings import apply_rotary_emb_allegro + + query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1]) + key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1]) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class AuraFlowAttnProcessor2_0: """Attention processor used typically in processing Aura Flow.""" diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index ccf4552b2a5e..9628fe7f21b0 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -1,5 +1,6 @@ from .autoencoder_asym_kl import AsymmetricAutoencoderKL from .autoencoder_kl import AutoencoderKL +from .autoencoder_kl_allegro import AutoencoderKLAllegro from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_oobleck import AutoencoderOobleck diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py new file mode 100644 index 000000000000..4836de7e16ab --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -0,0 +1,1155 @@ +# Copyright 2024 The RhymesAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.accelerate_utils import apply_forward_hook +from ..attention_processor import Attention, SpatialNorm +from ..autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution +from ..downsampling import Downsample2D +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..resnet import ResnetBlock2D +from ..upsampling import Upsample2D + + +class AllegroTemporalConvLayer(nn.Module): + r""" + Temporal convolutional layer that can be used for video (sequence of images) input. Code adapted from: + https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 + """ + + def __init__( + self, + in_dim: int, + out_dim: Optional[int] = None, + dropout: float = 0.0, + norm_num_groups: int = 32, + up_sample: bool = False, + down_sample: bool = False, + stride: int = 1, + ) -> None: + super().__init__() + + out_dim = out_dim or in_dim + pad_h = pad_w = int((stride - 1) * 0.5) + pad_t = 0 + + self.down_sample = down_sample + self.up_sample = up_sample + + if down_sample: + self.conv1 = nn.Sequential( + nn.GroupNorm(norm_num_groups, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (2, stride, stride), stride=(2, 1, 1), padding=(0, pad_h, pad_w)), + ) + elif up_sample: + self.conv1 = nn.Sequential( + nn.GroupNorm(norm_num_groups, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim * 2, (1, stride, stride), padding=(0, pad_h, pad_w)), + ) + else: + self.conv1 = nn.Sequential( + nn.GroupNorm(norm_num_groups, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)), + ) + self.conv2 = nn.Sequential( + nn.GroupNorm(norm_num_groups, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)), + ) + self.conv3 = nn.Sequential( + nn.GroupNorm(norm_num_groups, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)), + ) + self.conv4 = nn.Sequential( + nn.GroupNorm(norm_num_groups, out_dim), + nn.SiLU(), + nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)), + ) + + @staticmethod + def _pad_temporal_dim(hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2) + return hidden_states + + def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor: + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + if self.down_sample: + identity = hidden_states[:, :, ::2] + elif self.up_sample: + identity = hidden_states.repeat_interleave(2, dim=2) + else: + identity = hidden_states + + if self.down_sample or self.up_sample: + hidden_states = self.conv1(hidden_states) + else: + hidden_states = self._pad_temporal_dim(hidden_states) + hidden_states = self.conv1(hidden_states) + + if self.up_sample: + hidden_states = hidden_states.unflatten(1, (2, -1)).permute(0, 2, 3, 1, 4, 5).flatten(2, 3) + + hidden_states = self._pad_temporal_dim(hidden_states) + hidden_states = self.conv2(hidden_states) + + hidden_states = self._pad_temporal_dim(hidden_states) + hidden_states = self.conv3(hidden_states) + + hidden_states = self._pad_temporal_dim(hidden_states) + hidden_states = self.conv4(hidden_states) + + hidden_states = identity + hidden_states + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + + return hidden_states + + +class AllegroDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + spatial_downsample: bool = True, + temporal_downsample: bool = False, + downsample_padding: int = 1, + ): + super().__init__() + + resnets = [] + temp_convs = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + AllegroTemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if temporal_downsample: + self.temp_convs_down = AllegroTemporalConvLayer( + out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, down_sample=True, stride=3 + ) + self.add_temp_downsample = temporal_downsample + + if spatial_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = temp_conv(hidden_states, batch_size=batch_size) + + if self.add_temp_downsample: + hidden_states = self.temp_convs_down(hidden_states, batch_size=batch_size) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return hidden_states + + +class AllegroUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + temb_channels: Optional[int] = None, + ): + super().__init__() + + resnets = [] + temp_convs = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + AllegroTemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + self.add_temp_upsample = temporal_upsample + if temporal_upsample: + self.temp_conv_up = AllegroTemporalConvLayer( + out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, up_sample=True, stride=3 + ) + + if spatial_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = temp_conv(hidden_states, batch_size=batch_size) + + if self.add_temp_upsample: + hidden_states = self.temp_conv_up(hidden_states, batch_size=batch_size) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return hidden_states + + +class AllegroMidBlock3DConv(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + AllegroTemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ] + attentions = [] + + if attention_head_dim is None: + attention_head_dim = in_channels + + for _ in range(num_layers): + if add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + temp_convs.append( + AllegroTemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.resnets[0](hidden_states, temb=None) + + hidden_states = self.temp_convs[0](hidden_states, batch_size=batch_size) + + for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]): + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, temb=None) + hidden_states = temp_conv(hidden_states, batch_size=batch_size) + + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return hidden_states + + +class AllegroEncoder3D(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ( + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + temporal_downsample_blocks: Tuple[bool, ...] = [True, True, False, False], + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + ): + super().__init__() + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.temp_conv_in = nn.Conv3d( + in_channels=block_out_channels[0], + out_channels=block_out_channels[0], + kernel_size=(3, 1, 1), + padding=(1, 0, 0), + ) + + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "AllegroDownBlock3D": + down_block = AllegroDownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + spatial_downsample=not is_final_block, + temporal_downsample=temporal_downsample_blocks[i], + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + ) + else: + raise ValueError("Invalid `down_block_type` encountered. Must be `AllegroDownBlock3D`") + + self.down_blocks.append(down_block) + + # mid + self.mid_block = AllegroMidBlock3DConv( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + + self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0)) + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + batch_size = sample.shape[0] + + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_in(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + residual = sample + sample = self.temp_conv_in(sample) + sample = sample + residual + + if self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # Down blocks + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) + + # Mid block + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + else: + # Down blocks + for down_block in self.down_blocks: + sample = down_block(sample) + + # Mid block + sample = self.mid_block(sample) + + # Post process + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + residual = sample + sample = self.temp_conv_out(sample) + sample = sample + residual + + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_out(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return sample + + +class AllegroDecoder3D(nn.Module): + def __init__( + self, + in_channels: int = 4, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ( + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + ), + temporal_upsample_blocks: Tuple[bool, ...] = [False, True, True, False], + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + ): + super().__init__() + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.temp_conv_in = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0)) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = AllegroMidBlock3DConv( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + if up_block_type == "AllegroUpBlock3D": + up_block = AllegroUpBlock3D( + num_layers=layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + spatial_upsample=not is_final_block, + temporal_upsample=temporal_upsample_blocks[i], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + else: + raise ValueError("Invalid `UP_block_type` encountered. Must be `AllegroUpBlock3D`") + + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + + self.conv_act = nn.SiLU() + + self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3, 1, 1), padding=(1, 0, 0)) + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + batch_size = sample.shape[0] + + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_in(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + residual = sample + sample = self.temp_conv_in(sample) + sample = sample + residual + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + if self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # Mid block + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + + # Up blocks + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) + + else: + # Mid block + sample = self.mid_block(sample) + sample = sample.to(upscale_dtype) + + # Up blocks + for up_block in self.up_blocks: + sample = up_block(sample) + + # Post process + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + residual = sample + sample = self.temp_conv_out(sample) + sample = sample + residual + + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_out(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return sample + + +class AutoencoderKLAllegro(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in + [Allegro](https://github.com/rhymes-ai/Allegro). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, defaults to `3`): + Number of channels in the input image. + out_channels (int, defaults to `3`): + Number of channels in the output. + down_block_types (`Tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`): + Tuple of strings denoting which types of down blocks to use. + up_block_types (`Tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`): + Tuple of strings denoting which types of up blocks to use. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + Tuple of integers denoting number of output channels in each block. + temporal_downsample_blocks (`Tuple[bool, ...]`, defaults to `(True, True, False, False)`): + Tuple of booleans denoting which blocks to enable temporal downsampling in. + latent_channels (`int`, defaults to `4`): + Number of channels in latents. + layers_per_block (`int`, defaults to `2`): + Number of resnet or attention or temporal convolution layers per down/up block. + act_fn (`str`, defaults to `"silu"`): + The activation function to use. + norm_num_groups (`int`, defaults to `32`): + Number of groups to use in normalization layers. + temporal_compression_ratio (`int`, defaults to `4`): + Ratio by which temporal dimension of samples are compressed. + sample_size (`int`, defaults to `320`): + Default latent size. + scaling_factor (`float`, defaults to `0.13235`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ( + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + ), + up_block_types: Tuple[str, ...] = ( + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + temporal_downsample_blocks: Tuple[bool, ...] = (True, True, False, False), + temporal_upsample_blocks: Tuple[bool, ...] = (False, True, True, False), + latent_channels: int = 4, + layers_per_block: int = 2, + act_fn: str = "silu", + norm_num_groups: int = 32, + temporal_compression_ratio: float = 4, + sample_size: int = 320, + scaling_factor: float = 0.13, + force_upcast: bool = True, + ) -> None: + super().__init__() + + self.encoder = AllegroEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + temporal_downsample_blocks=temporal_downsample_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + self.decoder = AllegroDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + temporal_upsample_blocks=temporal_upsample_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + + # TODO(aryan): For the 1.0.0 refactor, `temporal_compression_ratio` can be inferred directly and we don't need + # to use a specific parameter here or in other VAEs. + + self.use_slicing = False + self.use_tiling = False + + self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1) + self.tile_overlap_t = 8 + self.tile_overlap_h = 120 + self.tile_overlap_w = 80 + sample_frames = 24 + + self.kernel = (sample_frames, sample_size, sample_size) + self.stride = ( + sample_frames - self.tile_overlap_t, + sample_size - self.tile_overlap_h, + sample_size - self.tile_overlap_w, + ) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)): + module.gradient_checkpointing = value + + def enable_tiling(self) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = True + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + # TODO(aryan) + # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + if self.use_tiling: + return self.tiled_encode(x) + + raise NotImplementedError("Encoding without tiling has not been implemented yet.") + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of videos into latents. + + Args: + x (`torch.Tensor`): + Input batch of videos. + return_dict (`bool`, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + # TODO(aryan): refactor tiling implementation + # if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height): + if self.use_tiling: + return self.tiled_decode(z) + + raise NotImplementedError("Decoding without tiling has not been implemented yet.") + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of videos. + + Args: + z (`torch.Tensor`): + Input batch of latent vectors. + return_dict (`bool`, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + local_batch_size = 1 + rs = self.spatial_compression_ratio + rt = self.config.temporal_compression_ratio + + batch_size, num_channels, num_frames, height, width = x.shape + + output_num_frames = math.floor((num_frames - self.kernel[0]) / self.stride[0]) + 1 + output_height = math.floor((height - self.kernel[1]) / self.stride[1]) + 1 + output_width = math.floor((width - self.kernel[2]) / self.stride[2]) + 1 + + count = 0 + output_latent = x.new_zeros( + ( + output_num_frames * output_height * output_width, + 2 * self.config.latent_channels, + self.kernel[0] // rt, + self.kernel[1] // rs, + self.kernel[2] // rs, + ) + ) + vae_batch_input = x.new_zeros((local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2])) + + for i in range(output_num_frames): + for j in range(output_height): + for k in range(output_width): + n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0] + h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1] + w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2] + + video_cube = x[:, :, n_start:n_end, h_start:h_end, w_start:w_end] + vae_batch_input[count % local_batch_size] = video_cube + + if ( + count % local_batch_size == local_batch_size - 1 + or count == output_num_frames * output_height * output_width - 1 + ): + latent = self.encoder(vae_batch_input) + + if ( + count == output_num_frames * output_height * output_width - 1 + and count % local_batch_size != local_batch_size - 1 + ): + output_latent[count - count % local_batch_size :] = latent[: count % local_batch_size + 1] + else: + output_latent[count - local_batch_size + 1 : count + 1] = latent + + vae_batch_input = x.new_zeros( + (local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2]) + ) + + count += 1 + + latent = x.new_zeros( + (batch_size, 2 * self.config.latent_channels, num_frames // rt, height // rs, width // rs) + ) + output_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs + output_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs + output_overlap = ( + output_kernel[0] - output_stride[0], + output_kernel[1] - output_stride[1], + output_kernel[2] - output_stride[2], + ) + + for i in range(output_num_frames): + n_start, n_end = i * output_stride[0], i * output_stride[0] + output_kernel[0] + for j in range(output_height): + h_start, h_end = j * output_stride[1], j * output_stride[1] + output_kernel[1] + for k in range(output_width): + w_start, w_end = k * output_stride[2], k * output_stride[2] + output_kernel[2] + latent_mean = _prepare_for_blend( + (i, output_num_frames, output_overlap[0]), + (j, output_height, output_overlap[1]), + (k, output_width, output_overlap[2]), + output_latent[i * output_height * output_width + j * output_width + k].unsqueeze(0), + ) + latent[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean + + latent = latent.permute(0, 2, 1, 3, 4).flatten(0, 1) + latent = self.quant_conv(latent) + latent = latent.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return latent + + def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + local_batch_size = 1 + rs = self.spatial_compression_ratio + rt = self.config.temporal_compression_ratio + + latent_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs + latent_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs + + batch_size, num_channels, num_frames, height, width = z.shape + + ## post quant conv (a mapping) + z = z.permute(0, 2, 1, 3, 4).flatten(0, 1) + z = self.post_quant_conv(z) + z = z.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + output_num_frames = math.floor((num_frames - latent_kernel[0]) / latent_stride[0]) + 1 + output_height = math.floor((height - latent_kernel[1]) / latent_stride[1]) + 1 + output_width = math.floor((width - latent_kernel[2]) / latent_stride[2]) + 1 + + count = 0 + decoded_videos = z.new_zeros( + ( + output_num_frames * output_height * output_width, + self.config.out_channels, + self.kernel[0], + self.kernel[1], + self.kernel[2], + ) + ) + vae_batch_input = z.new_zeros( + (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2]) + ) + + for i in range(output_num_frames): + for j in range(output_height): + for k in range(output_width): + n_start, n_end = i * latent_stride[0], i * latent_stride[0] + latent_kernel[0] + h_start, h_end = j * latent_stride[1], j * latent_stride[1] + latent_kernel[1] + w_start, w_end = k * latent_stride[2], k * latent_stride[2] + latent_kernel[2] + + current_latent = z[:, :, n_start:n_end, h_start:h_end, w_start:w_end] + vae_batch_input[count % local_batch_size] = current_latent + + if ( + count % local_batch_size == local_batch_size - 1 + or count == output_num_frames * output_height * output_width - 1 + ): + current_video = self.decoder(vae_batch_input) + + if ( + count == output_num_frames * output_height * output_width - 1 + and count % local_batch_size != local_batch_size - 1 + ): + decoded_videos[count - count % local_batch_size :] = current_video[ + : count % local_batch_size + 1 + ] + else: + decoded_videos[count - local_batch_size + 1 : count + 1] = current_video + + vae_batch_input = z.new_zeros( + (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2]) + ) + + count += 1 + + video = z.new_zeros((batch_size, self.config.out_channels, num_frames * rt, height * rs, width * rs)) + video_overlap = ( + self.kernel[0] - self.stride[0], + self.kernel[1] - self.stride[1], + self.kernel[2] - self.stride[2], + ) + + for i in range(output_num_frames): + n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0] + for j in range(output_height): + h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1] + for k in range(output_width): + w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2] + out_video_blend = _prepare_for_blend( + (i, output_num_frames, video_overlap[0]), + (j, output_height, video_overlap[1]), + (k, output_width, video_overlap[2]), + decoded_videos[i * output_height * output_width + j * output_width + k].unsqueeze(0), + ) + video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend + + video = video.permute(0, 2, 1, 3, 4).contiguous() + return video + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + encoder_local_batch_size: int = 2, + decoder_local_batch_size: int = 2, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*): + PyTorch random number generator. + encoder_local_batch_size (`int`, *optional*, defaults to 2): + Local batch size for the encoder's batch inference. + decoder_local_batch_size (`int`, *optional*, defaults to 2): + Local batch size for the decoder's batch inference. + """ + x = sample + posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + +def _prepare_for_blend(n_param, h_param, w_param, x): + # TODO(aryan): refactor + n, n_max, overlap_n = n_param + h, h_max, overlap_h = h_param + w, w_max, overlap_w = w_param + if overlap_n > 0: + if n > 0: # the head overlap part decays from 0 to 1 + x[:, :, 0:overlap_n, :, :] = x[:, :, 0:overlap_n, :, :] * ( + torch.arange(0, overlap_n).float().to(x.device) / overlap_n + ).reshape(overlap_n, 1, 1) + if n < n_max - 1: # the tail overlap part decays from 1 to 0 + x[:, :, -overlap_n:, :, :] = x[:, :, -overlap_n:, :, :] * ( + 1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n + ).reshape(overlap_n, 1, 1) + if h > 0: + x[:, :, :, 0:overlap_h, :] = x[:, :, :, 0:overlap_h, :] * ( + torch.arange(0, overlap_h).float().to(x.device) / overlap_h + ).reshape(overlap_h, 1) + if h < h_max - 1: + x[:, :, :, -overlap_h:, :] = x[:, :, :, -overlap_h:, :] * ( + 1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h + ).reshape(overlap_h, 1) + if w > 0: + x[:, :, :, :, 0:overlap_w] = x[:, :, :, :, 0:overlap_w] * ( + torch.arange(0, overlap_w).float().to(x.device) / overlap_w + ) + if w < w_max - 1: + x[:, :, :, :, -overlap_w:] = x[:, :, :, :, -overlap_w:] * ( + 1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w + ) + return x diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 44f01c46ebe8..66917dce6107 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -564,6 +564,42 @@ def combine_time_height_width(freqs_t, freqs_h, freqs_w): return cos, sin +def get_3d_rotary_pos_embed_allegro( + embed_dim, + crops_coords, + grid_size, + temporal_size, + interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0), + theta: int = 10000, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # TODO(aryan): docs + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) + + # Compute dimensions for each axis + dim_t = embed_dim // 3 + dim_h = embed_dim // 3 + dim_w = embed_dim // 3 + + # Temporal frequencies + freqs_t = get_1d_rotary_pos_embed( + dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False + ) + # Spatial frequencies for height and width + freqs_h = get_1d_rotary_pos_embed( + dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False + ) + freqs_w = get_1d_rotary_pos_embed( + dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False + ) + + return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w + + def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): """ RoPE for image tokens with 2d structure. @@ -684,7 +720,7 @@ def get_1d_rotary_pos_embed( freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] return freqs_cos, freqs_sin elif use_real: - # stable audio + # stable audio, allegro freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] return freqs_cos, freqs_sin @@ -743,6 +779,24 @@ def apply_rotary_emb( return x_out.type_as(x) +def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions): + # TODO(aryan): rewrite + def apply_1d_rope(tokens, pos, cos, sin): + cos = F.embedding(pos, cos)[:, None, :, :] + sin = F.embedding(pos, sin)[:, None, :, :] + x1, x2 = tokens[..., : tokens.shape[-1] // 2], tokens[..., tokens.shape[-1] // 2 :] + tokens_rotated = torch.cat((-x2, x1), dim=-1) + return (tokens.float() * cos + tokens_rotated.float() * sin).to(tokens.dtype) + + (t_cos, t_sin), (h_cos, h_sin), (w_cos, w_sin) = freqs_cis + t, h, w = x.chunk(3, dim=-1) + t = apply_1d_rope(t, positions[0], t_cos, t_sin) + h = apply_1d_rope(h, positions[1], h_cos, h_sin) + w = apply_1d_rope(w, positions[2], w_cos, w_sin) + x = torch.cat([t, h, w], dim=-1) + return x + + class FluxPosEmbed(nn.Module): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: List[int]): diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 029c147fcbac..87dec66935da 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -22,10 +22,7 @@ from ..utils import is_torch_version from .activations import get_activation -from .embeddings import ( - CombinedTimestepLabelEmbeddings, - PixArtAlphaCombinedTimestepSizeEmbeddings, -) +from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings class AdaLayerNorm(nn.Module): @@ -266,6 +263,7 @@ def forward( hidden_dtype: Optional[torch.dtype] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # No modulation happening here. + added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None} embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) return self.linear(self.silu(embedded_timestep)), embedded_timestep diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 58787c079ea8..873a2bbecf05 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -14,6 +14,7 @@ from .stable_audio_transformer import StableAudioDiTModel from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel + from .transformer_allegro import AllegroTransformer3DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel from .transformer_sd3 import SD3Transformer2DModel diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py new file mode 100644 index 000000000000..f756399a378a --- /dev/null +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -0,0 +1,422 @@ +# Copyright 2024 The RhymesAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward +from ..attention_processor import AllegroAttnProcessor2_0, Attention +from ..embeddings import PatchEmbed, PixArtAlphaTextProjection +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle + + +logger = logging.get_logger(__name__) + + +@maybe_allow_in_graph +class AllegroTransformerBlock(nn.Module): + r""" + Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model. + + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + cross_attention_dim (`int`, defaults to `2304`): + The dimension of the cross attention features. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to be used in feed-forward. + attention_bias (`bool`, defaults to `False`): + Whether or not to use bias in attention projection layers. + only_cross_attention (`bool`, defaults to `False`): + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, defaults to `1e-5`): + Epsilon value for normalization layers. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + attention_bias: bool = False, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + ): + super().__init__() + + # 1. Self Attention + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + processor=AllegroAttnProcessor2_0(), + ) + + # 2. Cross Attention + self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + processor=AllegroAttnProcessor2_0(), + ) + + # 3. Feed Forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + ) + + # 4. Scale-shift + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + temb: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb=None, + ) -> torch.Tensor: + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + temb.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = hidden_states + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + image_rotary_emb=None, + ) + hidden_states = attn_output + hidden_states + + # 2. Feed-forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + + # TODO(aryan): maybe following line is not required + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class AllegroTransformer3DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A 3D Transformer model for video-like data. + + Args: + patch_size (`int`, defaults to `2`): + The size of spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches to use in the patch embedding layer. + num_attention_heads (`int`, defaults to `24`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `96`): + The number of channels in each head. + in_channels (`int`, defaults to `4`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `4`): + The number of channels in the output. + num_layers (`int`, defaults to `32`): + The number of layers of Transformer blocks to use. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + cross_attention_dim (`int`, defaults to `2304`): + The dimension of the cross attention features. + attention_bias (`bool`, defaults to `True`): + Whether or not to use bias in the attention projection layers. + sample_height (`int`, defaults to `90`): + The height of the input latents. + sample_width (`int`, defaults to `160`): + The width of the input latents. + sample_frames (`int`, defaults to `22`): + The number of frames in the input latents. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + norm_elementwise_affine (`bool`, defaults to `False`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, defaults to `1e-6`): + The epsilon value to use in normalization layers. + caption_channels (`int`, defaults to `4096`): + Number of channels to use for projecting the caption embeddings. + interpolation_scale_h (`float`, defaults to `2.0`): + Scaling factor to apply in 3D positional embeddings across height dimension. + interpolation_scale_w (`float`, defaults to `2.0`): + Scaling factor to apply in 3D positional embeddings across width dimension. + interpolation_scale_t (`float`, defaults to `2.2`): + Scaling factor to apply in 3D positional embeddings across time dimension. + """ + + @register_to_config + def __init__( + self, + patch_size: int = 2, + patch_size_t: int = 1, + num_attention_heads: int = 24, + attention_head_dim: int = 96, + in_channels: int = 4, + out_channels: int = 4, + num_layers: int = 32, + dropout: float = 0.0, + cross_attention_dim: int = 2304, + attention_bias: bool = True, + sample_height: int = 90, + sample_width: int = 160, + sample_frames: int = 22, + activation_fn: str = "gelu-approximate", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + caption_channels: int = 4096, + interpolation_scale_h: float = 2.0, + interpolation_scale_w: float = 2.0, + interpolation_scale_t: float = 2.2, + ): + super().__init__() + + self.inner_dim = num_attention_heads * attention_head_dim + + interpolation_scale_t = ( + interpolation_scale_t + if interpolation_scale_t is not None + else ((sample_frames - 1) // 16 + 1) + if sample_frames % 2 == 1 + else sample_frames // 16 + ) + interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30 + interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40 + + # 1. Patch embedding + self.pos_embed = PatchEmbed( + height=sample_height, + width=sample_width, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=self.inner_dim, + pos_embed_type=None, + ) + + # 2. Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + AllegroTransformerBlock( + self.inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for _ in range(num_layers) + ] + ) + + # 3. Output projection & norm + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels) + + # 4. Timestep embeddings + self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False) + + # 5. Caption projection + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + return_dict: bool = True, + ): + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t = self.config.patch_size_t + p = self.config.patch_size + + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) attention_mask_vid, attention_mask_img = None, None + if attention_mask is not None and attention_mask.ndim == 4: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + # b, frame+use_image_num, h, w -> a video with images + # b, 1, h, w -> only images + attention_mask = attention_mask.to(hidden_states.dtype) + attention_mask = attention_mask[:, :num_frames] # [batch_size, num_frames, height, width] + + if attention_mask.numel() > 0: + attention_mask = attention_mask.unsqueeze(1) # [batch_size, 1, num_frames, height, width] + attention_mask = F.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p)) + attention_mask = attention_mask.flatten(1).view(batch_size, 1, -1) + + attention_mask = ( + (1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None + ) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Timestep embeddings + timestep, embedded_timestep = self.adaln_single( + timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + # 2. Patch embeddings + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.pos_embed(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) + + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1]) + + # 3. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + # TODO(aryan): Implement gradient checkpointing + if self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + timestep, + attention_mask, + encoder_attention_mask, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=timestep, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + image_rotary_emb=image_rotary_emb, + ) + + # 4. Output normalization & projection + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # 5. Unpatchify + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.reshape(batch_size, -1, num_frames, height, width) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7366520f4692..634088f1b51a 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -116,6 +116,7 @@ "VersatileDiffusionTextToImagePipeline", ] ) + _import_structure["allegro"] = ["AllegroPipeline"] _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"] _import_structure["animatediff"] = [ "AnimateDiffPipeline", @@ -454,6 +455,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_objects import * else: + from .allegro import AllegroPipeline from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline from .animatediff import ( AnimateDiffControlNetPipeline, diff --git a/src/diffusers/pipelines/allegro/__init__.py b/src/diffusers/pipelines/allegro/__init__.py new file mode 100644 index 000000000000..2162b825e0a2 --- /dev/null +++ b/src/diffusers/pipelines/allegro/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_allegro"] = ["AllegroPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_allegro import AllegroPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py new file mode 100644 index 000000000000..9314960f9618 --- /dev/null +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -0,0 +1,918 @@ +# Copyright 2024 The RhymesAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import math +import re +import urllib.parse as ul +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro +from ...models.embeddings import get_3d_rotary_pos_embed_allegro +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import AllegroPipelineOutput + + +logger = logging.get_logger(__name__) + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoencoderKLAllegro, AllegroPipeline + >>> from diffusers.utils import export_to_video + + >>> vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32) + >>> pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", vae=vae, torch_dtype=torch.bfloat16).to("cuda") + + >>> prompt = ( + ... "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, " + ... "the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this " + ... "location might be a popular spot for docking fishing boats." + ... ) + >>> video = pipe(prompt, guidance_scale=7.5, max_sequence_length=512).frames[0] + >>> export_to_video(video, "output.mp4", fps=15) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AllegroPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Allegro. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AllegroAutoEncoderKL3D`]): + Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. PixArt-Alpha uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`AllegroTransformer3DModel`]): + A text conditioned `AllegroTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLAllegro, + transformer: AllegroTransformer3DModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->512, num_images_per_prompt->num_videos_per_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 512, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_videos_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 512): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + num_frames, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if num_frames <= 0: + raise ValueError(f"`num_frames` have to be positive but is {num_frames}.") + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if num_frames % 2 == 0: + num_frames = math.ceil(num_frames / self.vae_scale_factor_temporal) + else: + num_frames = math.ceil((num_frames - 1) / self.vae_scale_factor_temporal) + 1 + + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = 1 / self.vae.config.scaling_factor * latents + frames = self.vae.decode(latents).sample + frames = frames.permute(0, 2, 1, 3, 4) # [batch_size, channels, num_frames, height, width] + return frames + + def _prepare_rotary_positional_embeddings( + self, + batch_size: int, + height: int, + width: int, + num_frames: int, + device: torch.device, + ): + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + start, stop = (0, 0), (grid_height, grid_width) + freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w = get_3d_rotary_pos_embed_allegro( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=(start, stop), + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + interpolation_scale=( + self.transformer.config.interpolation_scale_t, + self.transformer.config.interpolation_scale_h, + self.transformer.config.interpolation_scale_w, + ), + ) + + grid_t = torch.from_numpy(grid_t).to(device=device, dtype=torch.long) + grid_h = torch.from_numpy(grid_h).to(device=device, dtype=torch.long) + grid_w = torch.from_numpy(grid_w).to(device=device, dtype=torch.long) + + pos = torch.cartesian_prod(grid_t, grid_h, grid_w) + pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous() + grid_t, grid_h, grid_w = pos + + freqs_t = (freqs_t[0].to(device=device), freqs_t[1].to(device=device)) + freqs_h = (freqs_h[0].to(device=device), freqs_h[1].to(device=device)) + freqs_w = (freqs_w[0].to(device=device), freqs_w[1].to(device=device)) + + return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 100, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + clean_caption: bool = True, + max_sequence_length: int = 512, + ) -> Union[AllegroPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate videos that are closely linked to the text `prompt`, + usually at the expense of lower video quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + num_frames: (`int`, *optional*, defaults to 88): + The number controls the generated video frames. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated video. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate video. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + max_sequence_length (`int` defaults to `512`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated videos. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + num_frames = num_frames or self.transformer.config.sample_frames * self.vae_scale_factor_temporal + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + + self.check_inputs( + prompt, + num_frames, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + if prompt_embeds.ndim == 3: + prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare rotary embeddings + image_rotary_emb = self._prepare_rotary_positional_embeddings( + batch_size, height, width, latents.size(2), device + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + video = self.decode_latents(latents) + video = video[:, :, :num_frames, :height, :width] + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AllegroPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/allegro/pipeline_output.py b/src/diffusers/pipelines/allegro/pipeline_output.py new file mode 100644 index 000000000000..6a721783ca86 --- /dev/null +++ b/src/diffusers/pipelines/allegro/pipeline_output.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class AllegroPipelineOutput(BaseOutput): + r""" + Output class for Allegro pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 10d0399a6761..8a87b04a66cb 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class AllegroTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AsymmetricAutoencoderKL(metaclass=DummyObject): _backends = ["torch"] @@ -47,6 +62,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLAllegro(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLCogVideoX(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 9046a4f73533..83d160b08df4 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class AllegroPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AltDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_allegro.py b/tests/models/transformers/test_models_transformer_allegro.py new file mode 100644 index 000000000000..ad8b7a3824ba --- /dev/null +++ b/tests/models/transformers/test_models_transformer_allegro.py @@ -0,0 +1,79 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import AllegroTransformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class AllegroTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = AllegroTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + num_frames = 2 + height = 8 + width = 8 + embedding_dim = 16 + sequence_length = 16 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim // 2)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + } + + @property + def input_shape(self): + return (4, 2, 8, 8) + + @property + def output_shape(self): + return (4, 2, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings. + "num_attention_heads": 2, + "attention_head_dim": 8, + "in_channels": 4, + "out_channels": 4, + "num_layers": 1, + "cross_attention_dim": 16, + "sample_width": 8, + "sample_height": 8, + "sample_frames": 8, + "caption_channels": 8, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict diff --git a/tests/pipelines/allegro/__init__.py b/tests/pipelines/allegro/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py new file mode 100644 index 000000000000..d09fc0488378 --- /dev/null +++ b/tests/pipelines/allegro/test_allegro.py @@ -0,0 +1,337 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import inspect +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5Config, T5EncoderModel + +from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler +from diffusers.utils.testing_utils import ( + enable_full_determinism, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = AllegroPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = AllegroTransformer3DModel( + num_attention_heads=2, + attention_head_dim=12, + in_channels=4, + out_channels=4, + num_layers=1, + cross_attention_dim=24, + sample_width=8, + sample_height=8, + sample_frames=8, + caption_channels=24, + ) + + torch.manual_seed(0) + vae = AutoencoderKLAllegro( + in_channels=3, + out_channels=3, + down_block_types=( + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + ), + up_block_types=( + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + latent_channels=4, + layers_per_block=1, + norm_num_groups=2, + temporal_compression_ratio=4, + ) + + # TODO(aryan): Only for now, since VAE decoding without tiling is not yet implemented here + vae.enable_tiling() + + torch.manual_seed(0) + scheduler = DDIMScheduler() + + text_encoder_config = T5Config( + **{ + "d_ff": 37, + "d_kv": 8, + "d_model": 24, + "num_decoder_layers": 2, + "num_heads": 4, + "num_layers": 2, + "relative_attention_num_buckets": 8, + "vocab_size": 1103, + } + ) + text_encoder = T5EncoderModel(text_encoder_config) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "num_frames": 8, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + @unittest.skip("Decoding without tiling is not yet implemented") + def test_save_load_local(self): + pass + + @unittest.skip("Decoding without tiling is not yet implemented") + def test_save_load_optional_components(self): + pass + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (8, 3, 16, 16)) + expected_video = torch.randn(8, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + # TODO(aryan) + @unittest.skip("Decoding without tiling is not yet implemented.") + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_overlap_factor_height=1 / 12, + tile_overlap_factor_width=1 / 12, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + +@slow +@require_torch_gpu +class AllegroPipelineIntegrationTests(unittest.TestCase): + prompt = "A painting of a squirrel eating a burger." + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_allegro(self): + generator = torch.Generator("cpu").manual_seed(0) + + pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.float16) + pipe.enable_model_cpu_offload() + prompt = self.prompt + + videos = pipe( + prompt=prompt, + height=720, + width=1280, + num_frames=88, + generator=generator, + num_inference_steps=2, + output_type="pt", + ).frames + + video = videos[0] + expected_video = torch.randn(1, 88, 720, 1280, 3).numpy() + + max_diff = numpy_cosine_similarity_distance(video, expected_video) + assert max_diff < 1e-3, f"Max diff is too high. got {video}" diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 3e6f9d1278e8..295a94c1d2e4 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1103,6 +1103,7 @@ def _test_inference_batch_consistent( logger.setLevel(level=diffusers.logging.WARNING) for batch_size, batched_input in zip(batch_sizes, batched_inputs): + print(batch_size, batched_input) output = pipe(**batched_input) assert len(output[0]) == batch_size