From 0eeee618cffa8c648f4a753ace243100f465ce8d Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Wed, 22 Nov 2023 17:27:56 +0200 Subject: [PATCH 01/23] Adds an advanced version of the SD-XL DreamBooth LoRA training script supporting pivotal tuning (#5883) * sdxl dreambooth lora training script with pivotal tuning * bug fix - args missing from parse_args * code quality fixes * comment unnecessary code from TokenEmbedding handler class * fixup --------- Co-authored-by: Linoy Tsaban --- .../train_dreambooth_lora_sdxl_advanced.py | 1968 +++++++++++++++++ 1 file changed, 1968 insertions(+) create mode 100644 examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py new file mode 100644 index 000000000000..f032634a11f0 --- /dev/null +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -0,0 +1,1968 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import gc +import hashlib +import itertools +import logging +import math +import os +import shutil +import warnings +from pathlib import Path +from typing import List, Optional + +import numpy as np +import torch +import torch.nn.functional as F + +# imports of the TokenEmbeddingsHandler class +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from PIL.ImageOps import exif_transpose +from safetensors.torch import save_file +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DPMSolverMultistepScheduler, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) +from diffusers.loaders import LoraLoaderMixin +from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict +from diffusers.optimization import get_scheduler +from diffusers.training_utils import compute_snr, unet_lora_state_dict +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.24.0.dev0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + instance_prompt=str, + validation_prompt=str, + repo_folder=None, + vae_path=None, +): + img_str = "widget:\n" if images else "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f""" + - text: '{validation_prompt if validation_prompt else ' ' }' + output: + url: >- + "image_{i}.png" + """ + + yaml = f""" +--- +tags: +- stable-diffusion-xl +- stable-diffusion-xl-diffusers +- text-to-image +- diffusers +- lora +- template:sd-lora +widget: +{img_str} +--- +base_model: {base_model} +instance_prompt: {instance_prompt} +license: openrail++ +--- + """ + + model_card = f""" +# SDXL LoRA DreamBooth - {repo_id} + + + +## Model description + +These are {repo_id} LoRA adaption weights for {base_model}. +The weights were trained using [DreamBooth](https://dreambooth.github.io/). +LoRA for the text encoder was enabled: {train_text_encoder}. +Special VAE used for training: {vae_path}. + +## Trigger words + +You should use {instance_prompt} to trigger the image generation. + +## Download model + +Weights for this model are available in Safetensors format. + +[Download]({repo_id}/tree/main) them in the Files & versions tab. + +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that ๐Ÿค— Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--token_abstraction", + default="TOK", + help="identifier specifying the instance(or instances) as used in instance_prompt, validation prompt, " + "captions - e.g. TOK", + ) + + parser.add_argument( + "--num_new_tokens_per_abstraction", + default=2, + help="number of new tokens inserted to the tokenizers per token_abstraction value when " + "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new " + "tokens - ", + ) + + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="lora-dreambooth-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--crops_coords_top_left_h", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--crops_coords_top_left_w", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + + parser.add_argument( + "--train_text_encoder_ti", + action="store_true", + help=("Whether to use textual inversion"), + ) + + parser.add_argument( + "--train_text_encoder_ti_frac", + type=float, + default=0.5, + help=("The percentage of epochs to perform textual inversion"), + ) + + parser.add_argument( + "--train_text_encoder_frac", + type=float, + default=0.5, + help=("The percentage of epochs to perform text encoder tuning"), + ) + + parser.add_argument( + "--optimizer", + type=str, + default="adamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "fp16", "bf16"], + help=( + "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + if args.train_text_encoder and args.train_text_encoder_ti: + raise ValueError( + "Specify only one of `--train_text_encoder` or `--train_text_encoder_ti. " + "For full LoRA text encoder training check --train_text_encoder, for textual " + "inversion training check `--train_text_encoder_ti`" + ) + + if args.train_text_encoder_ti: + if isinstance(args.token_abstraction, str): + args.token_abstraction = [args.token_abstraction] + elif isinstance(args.token_abstraction, List): + args.token_abstraction = args.token_abstraction + else: + raise ValueError( + f"Unsupported type for --args.token_abstraction: {type(args.token_abstraction)}. " + f"Supported types are: str (for a single instance identifier) or List[str] (for multiple concepts)" + ) + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +# Taken from https://github.com/replicate/cog-sdxl/blob/main/dataset_and_utils.py +class TokenEmbeddingsHandler: + def __init__(self, text_encoders, tokenizers): + self.text_encoders = text_encoders + self.tokenizers = tokenizers + + self.train_ids: Optional[torch.Tensor] = None + self.inserting_toks: Optional[List[str]] = None + self.embeddings_settings = {} + + def initialize_new_tokens(self, inserting_toks: List[str]): + idx = 0 + for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): + assert isinstance(inserting_toks, list), "inserting_toks should be a list of strings." + assert all( + isinstance(tok, str) for tok in inserting_toks + ), "All elements in inserting_toks should be strings." + + self.inserting_toks = inserting_toks + special_tokens_dict = {"additional_special_tokens": self.inserting_toks} + tokenizer.add_special_tokens(special_tokens_dict) + text_encoder.resize_token_embeddings(len(tokenizer)) + + self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) + + # random initialization of new tokens + std_token_embedding = text_encoder.text_model.embeddings.token_embedding.weight.data.std() + + print(f"{idx} text encodedr's std_token_embedding: {std_token_embedding}") + + text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] = ( + torch.randn(len(self.train_ids), text_encoder.text_model.config.hidden_size) + .to(device=self.device) + .to(dtype=self.dtype) + * std_token_embedding + ) + self.embeddings_settings[ + f"original_embeddings_{idx}" + ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone() + self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding + + inu = torch.ones((len(tokenizer),), dtype=torch.bool) + inu[self.train_ids] = False + + self.embeddings_settings[f"index_no_updates_{idx}"] = inu + + print(self.embeddings_settings[f"index_no_updates_{idx}"].shape) + + idx += 1 + + def save_embeddings(self, file_path: str): + assert self.train_ids is not None, "Initialize new tokens before saving embeddings." + tensors = {} + for idx, text_encoder in enumerate(self.text_encoders): + assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len( + self.tokenizers[0] + ), "Tokenizers should be the same." + new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] + tensors[f"text_encoders_{idx}"] = new_token_embeddings + + save_file(tensors, file_path) + + @property + def dtype(self): + return self.text_encoders[0].dtype + + @property + def device(self): + return self.text_encoders[0].device + + # def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder): + # # Assuming new tokens are of the format + # self.inserting_toks = [f"" for i in range(loaded_embeddings.shape[0])] + # special_tokens_dict = {"additional_special_tokens": self.inserting_toks} + # tokenizer.add_special_tokens(special_tokens_dict) + # text_encoder.resize_token_embeddings(len(tokenizer)) + # + # self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) + # assert self.train_ids is not None, "New tokens could not be converted to IDs." + # text_encoder.text_model.embeddings.token_embedding.weight.data[ + # self.train_ids + # ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype) + + @torch.no_grad() + def retract_embeddings(self): + for idx, text_encoder in enumerate(self.text_encoders): + index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"] + text_encoder.text_model.embeddings.token_embedding.weight.data[index_no_updates] = ( + self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates] + .to(device=text_encoder.device) + .to(dtype=text_encoder.dtype) + ) + + # for the parts that were updated, we need to normalize them + # to have the same std as before + std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"] + + index_updates = ~index_no_updates + new_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] + off_ratio = std_token_embedding / new_embeddings.std() + + new_embeddings = new_embeddings * (off_ratio**0.1) + text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings + + # def load_embeddings(self, file_path: str): + # with safe_open(file_path, framework="pt", device=self.device.type) as f: + # for idx in range(len(self.text_encoders)): + # text_encoder = self.text_encoders[idx] + # tokenizer = self.tokenizers[idx] + # + # loaded_embeddings = f.get_tensor(f"text_encoders_{idx}") + # self._load_embeddings(loaded_embeddings, tokenizer, text_encoder) + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + token_abstraction_dict=None, # token mapping for textual inversion + size=1024, + repeats=1, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + self.token_abstraction_dict = token_abstraction_dict + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.instance_images[index % self.num_instance_images] + instance_image = exif_transpose(instance_image) + + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + if args.train_text_encoder_ti: + # replace instances of --token_abstraction in caption with the new tokens: "" etc. + for token_abs, token_replacement in self.token_abstraction_dict.items(): + caption = caption.replace(token_abs, "".join(token_replacement)) + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # costum prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def tokenize_prompt(tokenizer, prompt, add_special_tokens=False): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + add_special_tokens=add_special_tokens, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): + prompt_embeds_list = [] + + for i, text_encoder in enumerate(text_encoders): + if tokenizers is not None: + tokenizer = tokenizers[i] + text_input_ids = tokenize_prompt(tokenizer, prompt) + else: + assert text_input_ids_list is not None + text_input_ids = text_input_ids_list[i] + + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + kwargs_handlers=[kwargs], + ) + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "fp16": + torch_dtype = torch.float16 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + revision=args.revision, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + ) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + + if args.train_text_encoder_ti: + token_abstraction_dict = {} + token_idx = 0 + for i, token in enumerate(args.token_abstraction): + token_abstraction_dict[token] = [ + f"" for j in range(args.num_new_tokens_per_abstraction) + ] + token_idx += args.num_new_tokens_per_abstraction - 1 + + # replace instances of --token_abstraction in --instance_prompt with the new tokens: "" etc. + for token_abs, token_replacement in token_abstraction_dict.items(): + args.instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement)) + if args.with_prior_preservation: + args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement)) + + # initialize the new tokens for textual inversion + embedding_handler = TokenEmbeddingsHandler( + [text_encoder_one, text_encoder_two], [tokenizer_one, tokenizer_two] + ) + inserting_toks = [] + for new_tok in token_abstraction_dict.values(): + inserting_toks.extend(new_tok) + embedding_handler.initialize_new_tokens(inserting_toks=inserting_toks) + + # We only train the additional adapter LoRA layers + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + unet.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + + # The VAE is always in float32 to avoid NaN losses. + vae.to(accelerator.device, dtype=torch.float32) + + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, " + "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder_one.gradient_checkpointing_enable() + text_encoder_two.gradient_checkpointing_enable() + + # now we will add new LoRA weights to the attention layers + # Set correct lora layers + unet_lora_parameters = [] + for attn_processor_name, attn_processor in unet.attn_processors.items(): + # Parse the attention module. + attn_module = unet + for n in attn_processor_name.split(".")[:-1]: + attn_module = getattr(attn_module, n) + + # Set the `lora_layer` attribute of the attention-related matrices. + attn_module.to_q.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank + ) + ) + attn_module.to_k.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank + ) + ) + attn_module.to_v.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank + ) + ) + attn_module.to_out[0].set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_out[0].in_features, + out_features=attn_module.to_out[0].out_features, + rank=args.rank, + ) + ) + + # Accumulate the LoRA params to optimize. + unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) + unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) + + # The text encoder comes from ๐Ÿค— transformers, so we cannot directly modify it. + # So, instead, we monkey-patch the forward calls of its attention-blocks. + if args.train_text_encoder: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder( + text_encoder_one, dtype=torch.float32, rank=args.rank + ) + text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder( + text_encoder_two, dtype=torch.float32, rank=args.rank + ) + + # if we use textual inversion, we freeze all parameters except for the token embeddings + # in text encoder + elif args.train_text_encoder_ti: + text_lora_parameters_one = [] + for name, param in text_encoder_one.named_parameters(): + if "token_embedding" in name: + param.requires_grad = True + text_lora_parameters_one.append(param) + else: + param.requires_grad = False + text_lora_parameters_two = [] + for name, param in text_encoder_two.named_parameters(): + if "token_embedding" in name: + param.requires_grad = True + text_lora_parameters_two.append(param) + else: + param.requires_grad = False + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + # there are only two options here. Either are just the unet attn processor layers + # or there are the unet and text encoder atten layers + unet_lora_layers_to_save = None + text_encoder_one_lora_layers_to_save = None + text_encoder_two_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_lora_layers_to_save = unet_lora_state_dict(model) + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model) + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + StableDiffusionXLPipeline.save_lora_weights( + output_dir, + unet_lora_layers=unet_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, + text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + unet_ = None + text_encoder_one_ = None + text_encoder_two_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(accelerator.unwrap_model(unet))): + unet_ = model + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): + text_encoder_one_ = model + elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): + text_encoder_two_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) + + text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k} + LoraLoaderMixin.load_lora_into_text_encoder( + text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_ + ) + + text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k} + LoraLoaderMixin.load_lora_into_text_encoder( + text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_ + ) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training + freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti) + + # Optimization parameters + unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate} + if not freeze_text_encoder: + # different learning rate for text encoder and unet + text_lora_parameters_one_with_lr = { + "params": text_lora_parameters_one, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + text_lora_parameters_two_with_lr = { + "params": text_lora_parameters_two, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + params_to_optimize = [ + unet_lora_parameters_with_lr, + text_lora_parameters_one_with_lr, + text_lora_parameters_two_with_lr, + ] + else: + params_to_optimize = [unet_lora_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warn( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warn( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warn( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + if args.train_text_encoder and args.text_encoder_lr: + logger.warn( + f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be + # --learning_rate + params_to_optimize[1]["lr"] = args.learning_rate + params_to_optimize[2]["lr"] = args.learning_rate + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + # Computes additional embeddings/ids required by the SDXL UNet. + # regular text embeddings (when `train_text_encoder` is not True) + # pooled text embeddings + # time ids + + def compute_time_ids(): + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + original_size = (args.resolution, args.resolution) + target_size = (args.resolution, args.resolution) + crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + return add_time_ids + + if not args.train_text_encoder: + tokenizers = [tokenizer_one, tokenizer_two] + text_encoders = [text_encoder_one, text_encoder_two] + + def compute_text_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds + + # Handle instance prompt. + instance_time_ids = compute_time_ids() + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if freeze_text_encoder and not train_dataset.custom_instance_prompts: + instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + class_time_ids = compute_time_ids() + if freeze_text_encoder: + class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers + ) + + # Clear the memory here + if freeze_text_encoder and not train_dataset.custom_instance_prompts: + del tokenizers, text_encoders + gc.collect() + torch.cuda.empty_cache() + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + add_time_ids = instance_time_ids + if args.with_prior_preservation: + add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0) + + # if --train_text_encoder_ti we need add_special_tokens to be True fo textual inversion + add_special_tokens = True if args.train_text_encoder_ti else False + + if not train_dataset.custom_instance_prompts: + if freeze_text_encoder: + prompt_embeds = instance_prompt_hidden_states + unet_add_text_embeds = instance_pooled_prompt_embeds + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) + # if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the + # batch prompts on all training steps + else: + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, add_special_tokens) + tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt, add_special_tokens) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, add_special_tokens) + class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt, add_special_tokens) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if not freeze_text_encoder: + unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + if args.train_text_encoder: + num_train_epochs_text_encoder = int(args.train_text_encoder_frac * args.num_train_epochs) + elif args.train_text_encoder_ti: # args.train_text_encoder_ti + num_train_epochs_text_encoder = int(args.train_text_encoder_ti_frac * args.num_train_epochs) + + for epoch in range(first_epoch, args.num_train_epochs): + # if performing any kind of optimization of text_encoder params + if args.train_text_encoder or args.train_text_encoder_ti: + if epoch == num_train_epochs_text_encoder: + print("PIVOT HALFWAY", epoch) + # stopping optimization of text_encoder params + params_to_optimize = params_to_optimize[:1] + # reinitializing the optimizer to optimize only on unet params + if args.optimizer.lower() == "prodigy": + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + else: # AdamW or 8-bit-AdamW + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + else: + # still optimizng the text encoder + text_encoder_one.train() + text_encoder_two.train() + # set top parameter requires_grad = True for gradient checkpointing works + if args.train_text_encoder: + text_encoder_one.text_model.embeddings.requires_grad_(True) + text_encoder_two.text_model.embeddings.requires_grad_(True) + + unet.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + prompts = batch["prompts"] + print(prompts) + # encode batch prompts when custom prompts are provided for each image - + if train_dataset.custom_instance_prompts: + if freeze_text_encoder: + prompt_embeds, unet_add_text_embeds = compute_text_embeddings( + prompts, text_encoders, tokenizers + ) + + else: + tokens_one = tokenize_prompt(tokenizer_one, prompts, add_special_tokens) + tokens_two = tokenize_prompt(tokenizer_two, prompts, add_special_tokens) + + # Convert images to latent space + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + model_input = model_input.to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # Calculate the elements to repeat depending on the use of prior-preservation and custom captions. + if not train_dataset.custom_instance_prompts: + elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz + elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz + + else: + elems_to_repeat_text_embeds = 1 + elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz + + # Predict the noise residual + if freeze_text_encoder: + unet_added_conditions = { + "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1), + "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1), + } + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) + model_pred = unet( + noisy_model_input, + timesteps, + prompt_embeds_input, + added_cond_kwargs=unet_added_conditions, + ).sample + else: + unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)} + prompt_embeds, pooled_prompt_embeds = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=None, + prompt=None, + text_input_ids_list=[tokens_one, tokens_two], + ) + unet_added_conditions.update( + {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)} + ) + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) + model_pred = unet( + noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions + ).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(noise_scheduler, timesteps) + base_weight = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two) + if (args.train_text_encoder or args.train_text_encoder_ti) + else unet_lora_parameters + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # every step, we reset the embeddings to the original embeddings. + if args.train_text_encoder_ti: + for idx, text_encoder in enumerate(text_encoders): + embedding_handler.retract_embeddings() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + if not args.train_text_encoder: + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=accelerator.unwrap_model(text_encoder_one), + text_encoder_2=accelerator.unwrap_model(text_encoder_two), + unet=accelerator.unwrap_model(unet), + revision=args.revision, + torch_dtype=weight_dtype, + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, **scheduler_args + ) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + pipeline_args = {"prompt": args.validation_prompt} + + with torch.cuda.amp.autocast(): + images = [ + pipeline(**pipeline_args, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + unet = unet.to(torch.float32) + unet_lora_layers = unet_lora_state_dict(unet) + + if args.train_text_encoder: + text_encoder_one = accelerator.unwrap_model(text_encoder_one) + text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32)) + text_encoder_two = accelerator.unwrap_model(text_encoder_two) + text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32)) + else: + text_encoder_lora_layers = None + text_encoder_2_lora_layers = None + + StableDiffusionXLPipeline.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + text_encoder_2_lora_layers=text_encoder_2_lora_layers, + ) + + # Final inference + # Load previous pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + pipeline = pipeline.to(accelerator.device) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + if args.push_to_hub: + if args.train_text_encoder_ti: + embedding_handler.save_embeddings( + f"{args.output_dir}/embeddings.safetensors", + ) + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + vae_path=args.pretrained_vae_model_name_or_path, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) From 5ffa6032444abb45c10db8c23c3cbe155f956069 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Thu, 23 Nov 2023 13:11:50 +0200 Subject: [PATCH 02/23] [bug fix] fix small bug in readme template of sdxl lora training script (#5906) * readme bug fix * style fix --------- Co-authored-by: Linoy Tsaban --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index dd7b29ca8842..9285c0e9fae7 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -112,7 +112,7 @@ def save_model_card( img_str += f""" - text: '{validation_prompt if validation_prompt else ' ' }' output: - url: >- + url: "image_{i}.png" """ @@ -125,7 +125,6 @@ def save_model_card( - diffusers - lora - template:sd-lora -widget: {img_str} --- base_model: {base_model} From 3003ff4947ea43fb56aa0df3da61c85652f24c69 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Thu, 23 Nov 2023 20:08:49 +0200 Subject: [PATCH 03/23] [bug fix] fix small bug in readme template of sdxl lora training script (#5914) readme improvement and metadata fix --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 9285c0e9fae7..f4e7887c1c13 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -126,7 +126,6 @@ def save_model_card( - lora - template:sd-lora {img_str} ---- base_model: {base_model} instance_prompt: {instance_prompt} license: openrail++ @@ -141,8 +140,11 @@ def save_model_card( ## Model description These are {repo_id} LoRA adaption weights for {base_model}. + The weights were trained using [DreamBooth](https://dreambooth.github.io/). + LoRA for the text encoder was enabled: {train_text_encoder}. + Special VAE used for training: {vae_path}. ## Trigger words From e5f232f76bc8a6f5167285f414f208517861083f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 24 Nov 2023 20:36:33 +0530 Subject: [PATCH 04/23] [Docs] add: 8bit inference with pixart alpha (#5814) * add: 8bit inference with pixart alpha * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * add: note on 4bit. * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * address comment --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Patrick von Platen --- docs/source/en/api/pipelines/pixart.md | 106 +++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/docs/source/en/api/pipelines/pixart.md b/docs/source/en/api/pipelines/pixart.md index 6fa44cd508e4..7d8ff2b36bf2 100644 --- a/docs/source/en/api/pipelines/pixart.md +++ b/docs/source/en/api/pipelines/pixart.md @@ -35,6 +35,112 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) +## Inference with under 8GB GPU VRAM + +Run the [`PixArtAlphaPipeline`] with under 8GB GPU VRAM by loading the text encoder in 8-bit precision. Let's walk through a full-fledged example. + +First, install the [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library: + +```bash +pip install -U bitsandbytes +``` + +Then load the text encoder in 8-bit: + +```python +from transformers import T5EncoderModel +from diffusers import PixArtAlphaPipeline +import torch + +text_encoder = T5EncoderModel.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", + subfolder="text_encoder", + load_in_8bit=True, + device_map="auto", + +) +pipe = PixArtAlphaPipeline.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", + text_encoder=text_encoder, + transformer=None, + device_map="auto" +) +``` + +Now, use the `pipe` to encode a prompt: + +```python +with torch.no_grad(): + prompt = "cute cat" + prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt) +``` + +Since text embeddings have been computed, remove the `text_encoder` and `pipe` from the memory, and free up som GPU VRAM: + +```python +import gc + +def flush(): + gc.collect() + torch.cuda.empty_cache() + +del text_encoder +del pipe +flush() +``` + +Then compute the latents with the prompt embeddings as inputs: + +```python +pipe = PixArtAlphaPipeline.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", + text_encoder=None, + torch_dtype=torch.float16, +).to("cuda") + +latents = pipe( + negative_prompt=None, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + num_images_per_prompt=1, + output_type="latent", +).images + +del pipe.transformer +flush() +``` + + + +Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded. + + + +Once the latents are computed, pass it off to the VAE to decode into a real image: + +```python +with torch.no_grad(): + image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0] +image = pipe.image_processor.postprocess(image, output_type="pil")[0] +image.save("cat.png") +``` + +By deleting components you aren't using and flushing the GPU VRAM, you should be able to run [`PixArtAlphaPipeline`] with under 8GB GPU VRAM. + +![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/8bits_cat.png) + +If you want a report of your memory-usage, run this [script](https://gist.github.com/sayakpaul/3ae0f847001d342af27018a96f467e4e). + + + +Text embeddings computed in 8-bit can impact the quality of the generated images because of the information loss in the representation space caused by the reduced precision. It's recommended to compare the outputs with and without 8-bit. + + + +While loading the `text_encoder`, you set `load_in_8bit` to `True`. You could also specify `load_in_4bit` to bring your memory requirements down even further to under 7GB. + ## PixArtAlphaPipeline [[autodoc]] PixArtAlphaPipeline From b978334d71ebc07e92aad2e5463da3b3a6c8c0e2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Nov 2023 17:46:00 +0100 Subject: [PATCH 05/23] [@cene555][Kandinsky 3.0] Add Kandinsky 3.0 (#5913) * finalize * finalize * finalize * add slow test * add slow test * add slow test * Fix more * add slow test * fix more * fix more * fix more * fix more * fix more * fix more * fix more * fix more * fix more * Better * Fix more * Fix more * add slow test * Add auto pipelines * add slow test * Add all * add slow test * add slow test * add slow test * add slow test * add slow test * Apply suggestions from code review * add slow test * add slow test --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/pipelines/kandinsky3.md | 24 + scripts/convert_kandinsky3_unet.py | 98 +++ src/diffusers/__init__.py | 6 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/attention_processor.py | 41 +- src/diffusers/models/unet_kandi3.py | 589 ++++++++++++++++++ src/diffusers/pipelines/__init__.py | 5 + src/diffusers/pipelines/auto_pipeline.py | 3 + .../pipelines/kandinsky3/__init__.py | 49 ++ .../kandinsky3/kandinsky3_pipeline.py | 452 ++++++++++++++ .../kandinsky3/kandinsky3img2img_pipeline.py | 460 ++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 15 + .../dummy_torch_and_transformers_objects.py | 30 + tests/convert_kandinsky3_unet.py | 98 +++ tests/pipelines/kandinsky3/__init__.py | 0 tests/pipelines/kandinsky3/test_kandinsky3.py | 237 +++++++ 17 files changed, 2110 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/api/pipelines/kandinsky3.md create mode 100644 scripts/convert_kandinsky3_unet.py create mode 100644 src/diffusers/models/unet_kandi3.py create mode 100644 src/diffusers/pipelines/kandinsky3/__init__.py create mode 100644 src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py create mode 100644 src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py create mode 100755 tests/convert_kandinsky3_unet.py create mode 100644 tests/pipelines/kandinsky3/__init__.py create mode 100644 tests/pipelines/kandinsky3/test_kandinsky3.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d2583121418e..e855ea36e8cf 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -278,6 +278,8 @@ title: Kandinsky 2.1 - local: api/pipelines/kandinsky_v22 title: Kandinsky 2.2 + - local: api/pipelines/kandinsky3 + title: Kandinsky 3 - local: api/pipelines/latent_consistency_models title: Latent Consistency Models - local: api/pipelines/latent_diffusion diff --git a/docs/source/en/api/pipelines/kandinsky3.md b/docs/source/en/api/pipelines/kandinsky3.md new file mode 100644 index 000000000000..cc4f87d47f58 --- /dev/null +++ b/docs/source/en/api/pipelines/kandinsky3.md @@ -0,0 +1,24 @@ + + +# Kandinsky 3 + +TODO + +## Kandinsky3Pipeline + +[[autodoc]] Kandinsky3Pipeline + - all + - __call__ + +## Kandinsky3Img2ImgPipeline + +[[autodoc]] Kandinsky3Img2ImgPipeline + - all + - __call__ diff --git a/scripts/convert_kandinsky3_unet.py b/scripts/convert_kandinsky3_unet.py new file mode 100644 index 000000000000..4fe8c54eb7fc --- /dev/null +++ b/scripts/convert_kandinsky3_unet.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +import argparse +import fnmatch + +from safetensors.torch import load_file + +from diffusers import Kandinsky3UNet + + +MAPPING = { + "to_time_embed.1": "time_embedding.linear_1", + "to_time_embed.3": "time_embedding.linear_2", + "in_layer": "conv_in", + "out_layer.0": "conv_norm_out", + "out_layer.2": "conv_out", + "down_samples": "down_blocks", + "up_samples": "up_blocks", + "projection_lin": "encoder_hid_proj.projection_linear", + "projection_ln": "encoder_hid_proj.projection_norm", + "feature_pooling": "add_time_condition", + "to_query": "to_q", + "to_key": "to_k", + "to_value": "to_v", + "output_layer": "to_out.0", + "self_attention_block": "attentions.0", +} + +DYNAMIC_MAP = { + "resnet_attn_blocks.*.0": "resnets_in.*", + "resnet_attn_blocks.*.1": ("attentions.*", 1), + "resnet_attn_blocks.*.2": "resnets_out.*", +} +# MAPPING = {} + + +def convert_state_dict(unet_state_dict): + """ + Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model. + Args: + unet_model (torch.nn.Module): The original U-Net model. + unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet model to match keys with. + + Returns: + OrderedDict: The converted state dictionary. + """ + # Example of renaming logic (this will vary based on your model's architecture) + converted_state_dict = {} + for key in unet_state_dict: + new_key = key + for pattern, new_pattern in MAPPING.items(): + new_key = new_key.replace(pattern, new_pattern) + + for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items(): + has_matched = False + if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*") and not has_matched: + star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1]) + + if isinstance(dyn_new_pattern, tuple): + new_star = star + dyn_new_pattern[-1] + dyn_new_pattern = dyn_new_pattern[0] + else: + new_star = star + + pattern = dyn_pattern.replace("*", str(star)) + new_pattern = dyn_new_pattern.replace("*", str(new_star)) + + new_key = new_key.replace(pattern, new_pattern) + has_matched = True + + converted_state_dict[new_key] = unet_state_dict[key] + + return converted_state_dict + + +def main(model_path, output_path): + # Load your original U-Net model + unet_state_dict = load_file(model_path) + + # Initialize your Kandinsky3UNet model + config = {} + + # Convert the state dict + converted_state_dict = convert_state_dict(unet_state_dict) + + unet = Kandinsky3UNet(config) + unet.load_state_dict(converted_state_dict) + + unet.save_pretrained(output_path) + print(f"Converted model saved to {output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format") + parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model") + parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model") + + args = parser.parse_args() + main(args.model_path, args.output_path) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 21e7fbd59f24..8a0dc2b923d3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -79,6 +79,7 @@ "AutoencoderTiny", "ConsistencyDecoderVAE", "ControlNetModel", + "Kandinsky3UNet", "ModelMixin", "MotionAdapter", "MultiAdapter", @@ -214,6 +215,8 @@ "IFPipeline", "IFSuperResolutionPipeline", "ImageTextPipelineOutput", + "Kandinsky3Img2ImgPipeline", + "Kandinsky3Pipeline", "KandinskyCombinedPipeline", "KandinskyImg2ImgCombinedPipeline", "KandinskyImg2ImgPipeline", @@ -446,6 +449,7 @@ AutoencoderTiny, ConsistencyDecoderVAE, ControlNetModel, + Kandinsky3UNet, ModelMixin, MotionAdapter, MultiAdapter, @@ -560,6 +564,8 @@ IFPipeline, IFSuperResolutionPipeline, ImageTextPipelineOutput, + Kandinsky3Img2ImgPipeline, + Kandinsky3Pipeline, KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, KandinskyImg2ImgPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index d45f56d43c32..de2e2848b848 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -36,6 +36,7 @@ _import_structure["unet_2d"] = ["UNet2DModel"] _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"] _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"] + _import_structure["unet_kandi3"] = ["Kandinsky3UNet"] _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] _import_structure["vq_model"] = ["VQModel"] @@ -63,6 +64,7 @@ from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel from .unet_3d_condition import UNet3DConditionModel + from .unet_kandi3 import Kandinsky3UNet from .unet_motion_model import MotionAdapter, UNetMotionModel from .vq_model import VQModel diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 6b86ba66db37..21eb3a32dc09 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -16,7 +16,7 @@ import torch import torch.nn.functional as F -from torch import nn +from torch import einsum, nn from ..utils import USE_PEFT_BACKEND, deprecate, logging from ..utils.import_utils import is_xformers_available @@ -2219,6 +2219,44 @@ def __call__( return hidden_states +# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe +# this way torch.compile and co. will work as well +class Kandi3AttnProcessor: + r""" + Default kandinsky3 proccesor for performing attention-related computations. + """ + + @staticmethod + def _reshape(hid_states, h): + b, n, f = hid_states.shape + d = f // h + return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3) + + def __call__( + self, + attn, + x, + context, + context_mask=None, + ): + query = self._reshape(attn.to_q(x), h=attn.num_heads) + key = self._reshape(attn.to_k(context), h=attn.num_heads) + value = self._reshape(attn.to_v(context), h=attn.num_heads) + + attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key) + + if context_mask is not None: + max_neg_value = -torch.finfo(attention_matrix.dtype).max + context_mask = context_mask.unsqueeze(1).unsqueeze(1) + attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value) + attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1) + + out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value) + out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1) + out = attn.to_out[0](out) + return out + + LORA_ATTENTION_PROCESSORS = ( LoRAAttnProcessor, LoRAAttnProcessor2_0, @@ -2244,6 +2282,7 @@ def __call__( LoRAXFormersAttnProcessor, IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, + Kandi3AttnProcessor, ) AttentionProcessor = Union[ diff --git a/src/diffusers/models/unet_kandi3.py b/src/diffusers/models/unet_kandi3.py new file mode 100644 index 000000000000..42e25a942f7d --- /dev/null +++ b/src/diffusers/models/unet_kandi3.py @@ -0,0 +1,589 @@ +import math +from dataclasses import dataclass +from typing import Dict, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput, logging +from .attention_processor import AttentionProcessor, Kandi3AttnProcessor +from .embeddings import TimestepEmbedding +from .modeling_utils import ModelMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class Kandinsky3UNetOutput(BaseOutput): + sample: torch.FloatTensor = None + + +# TODO(Yiyi): This class needs to be removed +def set_default_item(condition, item_1, item_2=None): + if condition: + return item_1 + else: + return item_2 + + +# TODO(Yiyi): This class needs to be removed +def set_default_layer(condition, layer_1, args_1=[], kwargs_1={}, layer_2=torch.nn.Identity, args_2=[], kwargs_2={}): + if condition: + return layer_1(*args_1, **kwargs_1) + else: + return layer_2(*args_2, **kwargs_2) + + +# TODO(Yiyi): This class should be removed and be replaced by Timesteps +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x, type_tensor=None): + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb) + emb = x[:, None] * emb[None, :] + return torch.cat((emb.sin(), emb.cos()), dim=-1) + + +class Kandinsky3EncoderProj(nn.Module): + def __init__(self, encoder_hid_dim, cross_attention_dim): + super().__init__() + self.projection_linear = nn.Linear(encoder_hid_dim, cross_attention_dim, bias=False) + self.projection_norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, x): + x = self.projection_linear(x) + x = self.projection_norm(x) + return x + + +class Kandinsky3UNet(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels: int = 4, + time_embedding_dim: int = 1536, + groups: int = 32, + attention_head_dim: int = 64, + layers_per_block: Union[int, Tuple[int]] = 3, + block_out_channels: Tuple[int] = (384, 768, 1536, 3072), + cross_attention_dim: Union[int, Tuple[int]] = 4096, + encoder_hid_dim: int = 4096, + ): + super().__init__() + + # TOOD(Yiyi): Give better name and put into config for the following 4 parameters + expansion_ratio = 4 + compression_ratio = 2 + add_cross_attention = (False, True, True, True) + add_self_attention = (False, True, True, True) + + out_channels = in_channels + init_channels = block_out_channels[0] // 2 + # TODO(Yiyi): Should be replaced with Timesteps class -> make sure that results are the same + # self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1) + self.time_proj = SinusoidalPosEmb(init_channels) + + self.time_embedding = TimestepEmbedding( + init_channels, + time_embedding_dim, + ) + + self.add_time_condition = Kandinsky3AttentionPooling( + time_embedding_dim, cross_attention_dim, attention_head_dim + ) + + self.conv_in = nn.Conv2d(in_channels, init_channels, kernel_size=3, padding=1) + + self.encoder_hid_proj = Kandinsky3EncoderProj(encoder_hid_dim, cross_attention_dim) + + hidden_dims = [init_channels] + list(block_out_channels) + in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:])) + text_dims = [set_default_item(is_exist, cross_attention_dim) for is_exist in add_cross_attention] + num_blocks = len(block_out_channels) * [layers_per_block] + layer_params = [num_blocks, text_dims, add_self_attention] + rev_layer_params = map(reversed, layer_params) + + cat_dims = [] + self.num_levels = len(in_out_dims) + self.down_blocks = nn.ModuleList([]) + for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate( + zip(in_out_dims, *layer_params) + ): + down_sample = level != (self.num_levels - 1) + cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0)) + self.down_blocks.append( + Kandinsky3DownSampleBlock( + in_dim, + out_dim, + time_embedding_dim, + text_dim, + res_block_num, + groups, + attention_head_dim, + expansion_ratio, + compression_ratio, + down_sample, + self_attention, + ) + ) + + self.up_blocks = nn.ModuleList([]) + for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate( + zip(reversed(in_out_dims), *rev_layer_params) + ): + up_sample = level != 0 + self.up_blocks.append( + Kandinsky3UpSampleBlock( + in_dim, + cat_dims.pop(), + out_dim, + time_embedding_dim, + text_dim, + res_block_num, + groups, + attention_head_dim, + expansion_ratio, + compression_ratio, + up_sample, + self_attention, + ) + ) + + self.conv_norm_out = nn.GroupNorm(groups, init_channels) + self.conv_act_out = nn.SiLU() + self.conv_out = nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(Kandi3AttnProcessor()) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True): + # TODO(Yiyi): Clean up the following variables - these names should not be used + # but instead only the ones that we pass to forward + x = sample + context_mask = encoder_attention_mask + context = encoder_hidden_states + + if not torch.is_tensor(timestep): + dtype = torch.float32 if isinstance(timestep, float) else torch.int32 + timestep = torch.tensor([timestep], dtype=dtype, device=sample.device) + elif len(timestep.shape) == 0: + timestep = timestep[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = timestep.expand(sample.shape[0]) + time_embed_input = self.time_proj(timestep).to(x.dtype) + time_embed = self.time_embedding(time_embed_input) + + context = self.encoder_hid_proj(context) + + if context is not None: + time_embed = self.add_time_condition(time_embed, context, context_mask) + + hidden_states = [] + x = self.conv_in(x) + for level, down_sample in enumerate(self.down_blocks): + x = down_sample(x, time_embed, context, context_mask) + if level != self.num_levels - 1: + hidden_states.append(x) + + for level, up_sample in enumerate(self.up_blocks): + if level != 0: + x = torch.cat([x, hidden_states.pop()], dim=1) + x = up_sample(x, time_embed, context, context_mask) + + x = self.conv_norm_out(x) + x = self.conv_act_out(x) + x = self.conv_out(x) + + if not return_dict: + return (x,) + return Kandinsky3UNetOutput(sample=x) + + +class Kandinsky3UpSampleBlock(nn.Module): + def __init__( + self, + in_channels, + cat_dim, + out_channels, + time_embed_dim, + context_dim=None, + num_blocks=3, + groups=32, + head_dim=64, + expansion_ratio=4, + compression_ratio=2, + up_sample=True, + self_attention=True, + ): + super().__init__() + up_resolutions = [[None, set_default_item(up_sample, True), None, None]] + [[None] * 4] * (num_blocks - 1) + hidden_channels = ( + [(in_channels + cat_dim, in_channels)] + + [(in_channels, in_channels)] * (num_blocks - 2) + + [(in_channels, out_channels)] + ) + attentions = [] + resnets_in = [] + resnets_out = [] + + self.self_attention = self_attention + self.context_dim = context_dim + + attentions.append( + set_default_layer( + self_attention, + Kandinsky3AttentionBlock, + (out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio), + layer_2=nn.Identity, + ) + ) + + for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions): + resnets_in.append( + Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution) + ) + attentions.append( + set_default_layer( + context_dim is not None, + Kandinsky3AttentionBlock, + (in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio), + layer_2=nn.Identity, + ) + ) + resnets_out.append( + Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets_in = nn.ModuleList(resnets_in) + self.resnets_out = nn.ModuleList(resnets_out) + + def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None): + for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out): + x = resnet_in(x, time_embed) + if self.context_dim is not None: + x = attention(x, time_embed, context, context_mask, image_mask) + x = resnet_out(x, time_embed) + + if self.self_attention: + x = self.attentions[0](x, time_embed, image_mask=image_mask) + return x + + +class Kandinsky3DownSampleBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + time_embed_dim, + context_dim=None, + num_blocks=3, + groups=32, + head_dim=64, + expansion_ratio=4, + compression_ratio=2, + down_sample=True, + self_attention=True, + ): + super().__init__() + attentions = [] + resnets_in = [] + resnets_out = [] + + self.self_attention = self_attention + self.context_dim = context_dim + + attentions.append( + set_default_layer( + self_attention, + Kandinsky3AttentionBlock, + (in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio), + layer_2=nn.Identity, + ) + ) + + up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, set_default_item(down_sample, False), None]] + hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1) + for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions): + resnets_in.append( + Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio) + ) + attentions.append( + set_default_layer( + context_dim is not None, + Kandinsky3AttentionBlock, + (out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio), + layer_2=nn.Identity, + ) + ) + resnets_out.append( + Kandinsky3ResNetBlock( + out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets_in = nn.ModuleList(resnets_in) + self.resnets_out = nn.ModuleList(resnets_out) + + def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None): + if self.self_attention: + x = self.attentions[0](x, time_embed, image_mask=image_mask) + + for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out): + x = resnet_in(x, time_embed) + if self.context_dim is not None: + x = attention(x, time_embed, context, context_mask, image_mask) + x = resnet_out(x, time_embed) + return x + + +class Kandinsky3ConditionalGroupNorm(nn.Module): + def __init__(self, groups, normalized_shape, context_dim): + super().__init__() + self.norm = nn.GroupNorm(groups, normalized_shape, affine=False) + self.context_mlp = nn.Sequential(nn.SiLU(), nn.Linear(context_dim, 2 * normalized_shape)) + self.context_mlp[1].weight.data.zero_() + self.context_mlp[1].bias.data.zero_() + + def forward(self, x, context): + context = self.context_mlp(context) + + for _ in range(len(x.shape[2:])): + context = context.unsqueeze(-1) + + scale, shift = context.chunk(2, dim=1) + x = self.norm(x) * (scale + 1.0) + shift + return x + + +# TODO(Yiyi): This class should ideally not even exist, it slows everything needlessly down. I'm pretty +# sure we can delete it and instead just pass an attention_mask +class Attention(nn.Module): + def __init__(self, in_channels, out_channels, context_dim, head_dim=64): + super().__init__() + assert out_channels % head_dim == 0 + self.num_heads = out_channels // head_dim + self.scale = head_dim**-0.5 + + # to_q + self.to_q = nn.Linear(in_channels, out_channels, bias=False) + # to_k + self.to_k = nn.Linear(context_dim, out_channels, bias=False) + # to_v + self.to_v = nn.Linear(context_dim, out_channels, bias=False) + processor = Kandi3AttnProcessor() + self.set_processor(processor) + # to_out + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(out_channels, out_channels, bias=False)) + + def set_processor(self, processor: "AttnProcessor"): # noqa: F821 + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def forward(self, x, context, context_mask=None, image_mask=None): + return self.processor( + self, + x, + context=context, + context_mask=context_mask, + ) + + +class Kandinsky3Block(nn.Module): + def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None): + super().__init__() + self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim) + self.activation = nn.SiLU() + self.up_sample = set_default_layer( + up_resolution is not None and up_resolution, + nn.ConvTranspose2d, + (in_channels, in_channels), + {"kernel_size": 2, "stride": 2}, + ) + padding = int(kernel_size > 1) + self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding) + self.down_sample = set_default_layer( + up_resolution is not None and not up_resolution, + nn.Conv2d, + (out_channels, out_channels), + {"kernel_size": 2, "stride": 2}, + ) + + def forward(self, x, time_embed): + x = self.group_norm(x, time_embed) + x = self.activation(x) + x = self.up_sample(x) + x = self.projection(x) + x = self.down_sample(x) + return x + + +class Kandinsky3ResNetBlock(nn.Module): + def __init__( + self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2, up_resolutions=4 * [None] + ): + super().__init__() + kernel_sizes = [1, 3, 3, 1] + hidden_channel = max(in_channels, out_channels) // compression_ratio + hidden_channels = ( + [(in_channels, hidden_channel)] + [(hidden_channel, hidden_channel)] * 2 + [(hidden_channel, out_channels)] + ) + self.resnet_blocks = nn.ModuleList( + [ + Kandinsky3Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution) + for (in_channel, out_channel), kernel_size, up_resolution in zip( + hidden_channels, kernel_sizes, up_resolutions + ) + ] + ) + self.shortcut_up_sample = set_default_layer( + True in up_resolutions, nn.ConvTranspose2d, (in_channels, in_channels), {"kernel_size": 2, "stride": 2} + ) + self.shortcut_projection = set_default_layer( + in_channels != out_channels, nn.Conv2d, (in_channels, out_channels), {"kernel_size": 1} + ) + self.shortcut_down_sample = set_default_layer( + False in up_resolutions, nn.Conv2d, (out_channels, out_channels), {"kernel_size": 2, "stride": 2} + ) + + def forward(self, x, time_embed): + out = x + for resnet_block in self.resnet_blocks: + out = resnet_block(out, time_embed) + + x = self.shortcut_up_sample(x) + x = self.shortcut_projection(x) + x = self.shortcut_down_sample(x) + x = x + out + return x + + +class Kandinsky3AttentionPooling(nn.Module): + def __init__(self, num_channels, context_dim, head_dim=64): + super().__init__() + self.attention = Attention(context_dim, num_channels, context_dim, head_dim) + + def forward(self, x, context, context_mask=None): + context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask) + return x + context.squeeze(1) + + +class Kandinsky3AttentionBlock(nn.Module): + def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4): + super().__init__() + self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim) + self.attention = Attention(num_channels, num_channels, context_dim or num_channels, head_dim) + + hidden_channels = expansion_ratio * num_channels + self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim) + self.feed_forward = nn.Sequential( + nn.Conv2d(num_channels, hidden_channels, kernel_size=1, bias=False), + nn.SiLU(), + nn.Conv2d(hidden_channels, num_channels, kernel_size=1, bias=False), + ) + + def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None): + height, width = x.shape[-2:] + out = self.in_norm(x, time_embed) + out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1) + context = context if context is not None else out + + if image_mask is not None: + mask_height, mask_width = image_mask.shape[-2:] + kernel_size = (mask_height // height, mask_width // width) + image_mask = F.max_pool2d(image_mask, kernel_size, kernel_size) + image_mask = image_mask.reshape(image_mask.shape[0], -1) + + out = self.attention(out, context, context_mask, image_mask) + out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width) + x = x + out + + out = self.out_norm(x, time_embed) + out = self.feed_forward(out) + x = x + out + return x diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 879bd6d98aa6..78c1b7c6285d 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -110,6 +110,7 @@ "KandinskyV22PriorEmb2EmbPipeline", "KandinskyV22PriorPipeline", ] + _import_structure["kandinsky3"] = ["Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline"] _import_structure["latent_consistency_models"] = [ "LatentConsistencyModelImg2ImgPipeline", "LatentConsistencyModelPipeline", @@ -338,6 +339,10 @@ KandinskyV22PriorEmb2EmbPipeline, KandinskyV22PriorPipeline, ) + from .kandinsky3 import ( + Kandinsky3Img2ImgPipeline, + Kandinsky3Pipeline, + ) from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline from .latent_diffusion import LDMTextToImagePipeline from .musicldm import MusicLDMPipeline diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 6396bbbbc278..a7c6cd82c8e7 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -42,6 +42,7 @@ KandinskyV22InpaintPipeline, KandinskyV22Pipeline, ) +from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline from .pixart_alpha import PixArtAlphaPipeline from .stable_diffusion import ( @@ -64,6 +65,7 @@ ("if", IFPipeline), ("kandinsky", KandinskyCombinedPipeline), ("kandinsky22", KandinskyV22CombinedPipeline), + ("kandinsky3", Kandinsky3Pipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline), ("wuerstchen", WuerstchenCombinedPipeline), @@ -79,6 +81,7 @@ ("if", IFImg2ImgPipeline), ("kandinsky", KandinskyImg2ImgCombinedPipeline), ("kandinsky22", KandinskyV22Img2ImgCombinedPipeline), + ("kandinsky3", Kandinsky3Img2ImgPipeline), ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline), ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline), ("lcm", LatentConsistencyModelImg2ImgPipeline), diff --git a/src/diffusers/pipelines/kandinsky3/__init__.py b/src/diffusers/pipelines/kandinsky3/__init__.py new file mode 100644 index 000000000000..4da3a83c0448 --- /dev/null +++ b/src/diffusers/pipelines/kandinsky3/__init__.py @@ -0,0 +1,49 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["kandinsky3_pipeline"] = ["Kandinsky3Pipeline"] + _import_structure["kandinsky3img2img_pipeline"] = ["Kandinsky3Img2ImgPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .kandinsky3_pipeline import Kandinsky3Pipeline + from .kandinsky3img2img_pipeline import Kandinsky3Img2ImgPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py b/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py new file mode 100644 index 000000000000..8ba1a4f637be --- /dev/null +++ b/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py @@ -0,0 +1,452 @@ +from typing import Callable, List, Optional, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...loaders import LoraLoaderMixin +from ...models import Kandinsky3UNet, VQModel +from ...schedulers import DDPMScheduler +from ...utils import ( + is_accelerate_available, + logging, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin): + model_cpu_offload_seq = "text_encoder->unet->movq" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: Kandinsky3UNet, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq + ) + + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + def process_embeds(self, embeddings, attention_mask, cut_context): + if cut_context: + embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0]) + max_seq_length = attention_mask.sum(-1).max() + 1 + embeddings = embeddings[:, :max_seq_length] + attention_mask = attention_mask[:, :max_seq_length] + return embeddings, attention_mask + + @torch.no_grad() + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + _cut_context=False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = 128 + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_embeds, attention_mask = self.process_embeds(prompt_embeds, attention_mask, _cut_context) + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(2) + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + attention_mask = attention_mask.repeat(num_images_per_prompt, 1) + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + if negative_prompt is not None: + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=128, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=negative_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds[:, : prompt_embeds.shape[1]] + negative_attention_mask = negative_attention_mask[:, : prompt_embeds.shape[1]] + negative_prompt_embeds = negative_prompt_embeds * negative_attention_mask.unsqueeze(2) + + else: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_attention_mask = torch.zeros_like(attention_mask) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + if negative_prompt_embeds.shape != prompt_embeds.shape: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_attention_mask = negative_attention_mask.repeat(num_images_per_prompt, 1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + negative_attention_mask = None + return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + num_inference_steps: int = 100, + guidance_scale: float = 3.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + height: Optional[int] = 1024, + width: Optional[int] = 1024, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + latents=None, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 3.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (ฮท) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + """ + cut_context = True + device = self._execution_device + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + _cut_context=cut_context, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool() + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latents + height, width = downscale_height_and_width(height, width, 8) + + latents = self.prepare_latents( + (batch_size * num_images_per_prompt, 4, height, width), + prompt_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + # TODO(Yiyi): Correct the following line and use correctly + # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=attention_mask, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + + noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond + # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type not in ["pt", "np", "pil"]: + raise ValueError( + f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}" + ) + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py b/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py new file mode 100644 index 000000000000..b043110cf1d7 --- /dev/null +++ b/src/diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py @@ -0,0 +1,460 @@ +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import PIL.Image +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...loaders import LoraLoaderMixin +from ...models import Kandinsky3UNet, VQModel +from ...schedulers import DDPMScheduler +from ...utils import ( + is_accelerate_available, + logging, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def downscale_height_and_width(height, width, scale_factor=8): + new_height = height // scale_factor**2 + if height % scale_factor**2 != 0: + new_height += 1 + new_width = width // scale_factor**2 + if width % scale_factor**2 != 0: + new_width += 1 + return new_height * scale_factor, new_width * scale_factor + + +def prepare_image(pil_image): + arr = np.array(pil_image.convert("RGB")) + arr = arr.astype(np.float32) / 127.5 - 1 + arr = np.transpose(arr, [2, 0, 1]) + image = torch.from_numpy(arr).unsqueeze(0) + return image + + +class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): + model_cpu_offload_seq = "text_encoder->unet->movq" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + unet: Kandinsky3UNet, + scheduler: DDPMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def remove_all_hooks(self): + if is_accelerate_available(): + from accelerate.hooks import remove_hook_from_module + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + for model in [self.text_encoder, self.unet]: + if model is not None: + remove_hook_from_module(model, recurse=True) + + self.unet_offload_hook = None + self.text_encoder_offload_hook = None + self.final_offload_hook = None + + def _process_embeds(self, embeddings, attention_mask, cut_context): + # return embeddings, attention_mask + if cut_context: + embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0]) + max_seq_length = attention_mask.sum(-1).max() + 1 + embeddings = embeddings[:, :max_seq_length] + attention_mask = attention_mask[:, :max_seq_length] + return embeddings, attention_mask + + @torch.no_grad() + def encode_prompt( + self, + prompt, + do_classifier_free_guidance=True, + num_images_per_prompt=1, + device=None, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + _cut_context=False, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and negative_prompt is not None: + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + max_length = 128 + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + prompt_embeds, attention_mask = self._process_embeds(prompt_embeds, attention_mask, _cut_context) + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(2) + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + attention_mask = attention_mask.repeat(num_images_per_prompt, 1) + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + if negative_prompt is not None: + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=128, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=negative_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds[:, : prompt_embeds.shape[1]] + negative_attention_mask = negative_attention_mask[:, : prompt_embeds.shape[1]] + negative_prompt_embeds = negative_prompt_embeds * negative_attention_mask.unsqueeze(2) + + else: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_attention_mask = torch.zeros_like(attention_mask) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + if negative_prompt_embeds.shape != prompt_embeds.shape: + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_attention_mask = negative_attention_mask.repeat(num_images_per_prompt, 1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + else: + negative_prompt_embeds = None + negative_attention_mask = None + return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.movq.encode(image).latent_dist.sample(generator) + + init_latents = self.movq.config.scaling_factor * init_latents + + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + + latents = init_latents + + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (ฮท) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to ฮท in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, + strength: float = 0.3, + num_inference_steps: int = 25, + guidance_scale: float = 3.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + latents=None, + ): + cut_context = True + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt( + prompt, + do_classifier_free_guidance, + num_images_per_prompt=num_images_per_prompt, + device=device, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + _cut_context=cut_context, + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool() + if not isinstance(image, list): + image = [image] + if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" + ) + + image = torch.cat([prepare_image(i) for i in image], dim=0) + image = image.to(dtype=prompt_embeds.dtype, device=device) + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + # 5. Prepare latents + latents = self.movq.encode(image)["latents"] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + latents = self.prepare_latents( + latents, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None: + self.text_encoder_offload_hook.offload() + + # 7. Denoising loop + # TODO(Yiyi): Correct the following line and use correctly + # num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=attention_mask, + )[0] + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + + noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type not in ["pt", "np", "pil"]: + raise ValueError( + f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}" + ) + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 090b1081fdaf..360727ab2fc5 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -77,6 +77,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class Kandinsky3UNet(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ModelMixin(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index d6200bcaf122..3386a95eb7d4 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -242,6 +242,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Kandinsky3Img2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class Kandinsky3Pipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class KandinskyCombinedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/convert_kandinsky3_unet.py b/tests/convert_kandinsky3_unet.py new file mode 100755 index 000000000000..4fe8c54eb7fc --- /dev/null +++ b/tests/convert_kandinsky3_unet.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +import argparse +import fnmatch + +from safetensors.torch import load_file + +from diffusers import Kandinsky3UNet + + +MAPPING = { + "to_time_embed.1": "time_embedding.linear_1", + "to_time_embed.3": "time_embedding.linear_2", + "in_layer": "conv_in", + "out_layer.0": "conv_norm_out", + "out_layer.2": "conv_out", + "down_samples": "down_blocks", + "up_samples": "up_blocks", + "projection_lin": "encoder_hid_proj.projection_linear", + "projection_ln": "encoder_hid_proj.projection_norm", + "feature_pooling": "add_time_condition", + "to_query": "to_q", + "to_key": "to_k", + "to_value": "to_v", + "output_layer": "to_out.0", + "self_attention_block": "attentions.0", +} + +DYNAMIC_MAP = { + "resnet_attn_blocks.*.0": "resnets_in.*", + "resnet_attn_blocks.*.1": ("attentions.*", 1), + "resnet_attn_blocks.*.2": "resnets_out.*", +} +# MAPPING = {} + + +def convert_state_dict(unet_state_dict): + """ + Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model. + Args: + unet_model (torch.nn.Module): The original U-Net model. + unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet model to match keys with. + + Returns: + OrderedDict: The converted state dictionary. + """ + # Example of renaming logic (this will vary based on your model's architecture) + converted_state_dict = {} + for key in unet_state_dict: + new_key = key + for pattern, new_pattern in MAPPING.items(): + new_key = new_key.replace(pattern, new_pattern) + + for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items(): + has_matched = False + if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*") and not has_matched: + star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1]) + + if isinstance(dyn_new_pattern, tuple): + new_star = star + dyn_new_pattern[-1] + dyn_new_pattern = dyn_new_pattern[0] + else: + new_star = star + + pattern = dyn_pattern.replace("*", str(star)) + new_pattern = dyn_new_pattern.replace("*", str(new_star)) + + new_key = new_key.replace(pattern, new_pattern) + has_matched = True + + converted_state_dict[new_key] = unet_state_dict[key] + + return converted_state_dict + + +def main(model_path, output_path): + # Load your original U-Net model + unet_state_dict = load_file(model_path) + + # Initialize your Kandinsky3UNet model + config = {} + + # Convert the state dict + converted_state_dict = convert_state_dict(unet_state_dict) + + unet = Kandinsky3UNet(config) + unet.load_state_dict(converted_state_dict) + + unet.save_pretrained(output_path) + print(f"Converted model saved to {output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format") + parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model") + parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model") + + args = parser.parse_args() + main(args.model_path, args.output_path) diff --git a/tests/pipelines/kandinsky3/__init__.py b/tests/pipelines/kandinsky3/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/kandinsky3/test_kandinsky3.py b/tests/pipelines/kandinsky3/test_kandinsky3.py new file mode 100644 index 000000000000..65297a36b157 --- /dev/null +++ b/tests/pipelines/kandinsky3/test_kandinsky3.py @@ -0,0 +1,237 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import ( + AutoPipelineForImage2Image, + AutoPipelineForText2Image, + Kandinsky3Pipeline, + Kandinsky3UNet, + VQModel, +) +from diffusers.image_processor import VaeImageProcessor +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from diffusers.utils.testing_utils import ( + enable_full_determinism, + load_image, + require_torch_gpu, + slow, +) + +from ..pipeline_params import ( + TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_PARAMS, +) +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class Kandinsky3PipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Kandinsky3Pipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS + test_xformers_attention = False + + @property + def dummy_movq_kwargs(self): + return { + "block_out_channels": [32, 64], + "down_block_types": ["DownEncoderBlock2D", "AttnDownEncoderBlock2D"], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 1, + "norm_num_groups": 8, + "norm_type": "spatial", + "num_vq_embeddings": 12, + "out_channels": 3, + "up_block_types": [ + "AttnUpDecoderBlock2D", + "UpDecoderBlock2D", + ], + "vq_embed_dim": 4, + } + + @property + def dummy_movq(self): + torch.manual_seed(0) + model = VQModel(**self.dummy_movq_kwargs) + return model + + def get_dummy_components(self, time_cond_proj_dim=None): + torch.manual_seed(0) + unet = Kandinsky3UNet( + in_channels=4, + time_embedding_dim=4, + groups=2, + attention_head_dim=4, + layers_per_block=3, + block_out_channels=(32, 64), + cross_attention_dim=4, + encoder_hid_dim=32, + ) + scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + steps_offset=1, + beta_schedule="squaredcos_cap_v2", + clip_sample=True, + thresholding=False, + ) + torch.manual_seed(0) + movq = self.dummy_movq + torch.manual_seed(0) + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + + torch.manual_seed(0) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "unet": unet, + "scheduler": scheduler, + "movq": movq, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "np", + "width": 16, + "height": 16, + } + return inputs + + def test_kandinsky3(self): + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)) + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 16, 16, 3) + + expected_slice = np.array([0.3768, 0.4373, 0.4865, 0.4890, 0.4299, 0.5122, 0.4921, 0.4924, 0.5599]) + + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + + def test_float16_inference(self): + super().test_float16_inference(expected_max_diff=1e-1) + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=1e-2) + + def test_model_cpu_offload_forward_pass(self): + # TODO(Yiyi) - this test should work, skipped for time reasons for now + pass + + +@slow +@require_torch_gpu +class Kandinsky3PipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_kandinskyV3(self): + pipe = AutoPipelineForText2Image.from_pretrained( + "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background." + + generator = torch.Generator(device="cpu").manual_seed(0) + + image = pipe(prompt, num_inference_steps=25, generator=generator).images[0] + + assert image.size == (1024, 1024) + + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png" + ) + + image_processor = VaeImageProcessor() + + image_np = image_processor.pil_to_numpy(image) + expected_image_np = image_processor.pil_to_numpy(expected_image) + + self.assertTrue(np.allclose(image_np, expected_image_np, atol=5e-2)) + + def test_kandinskyV3_img2img(self): + pipe = AutoPipelineForImage2Image.from_pretrained( + "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16 + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png" + ) + w, h = 512, 512 + image = image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) + prompt = "A painting of the inside of a subway train with tiny raccoons." + + image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0] + + assert image.size == (512, 512) + + expected_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/i2i.png" + ) + + image_processor = VaeImageProcessor() + + image_np = image_processor.pil_to_numpy(image) + expected_image_np = image_processor.pil_to_numpy(expected_image) + + self.assertTrue(np.allclose(image_np, expected_image_np, atol=5e-2)) From 2a7f43a73bda387385a47a15d7b6fe9be9c65eb2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Nov 2023 17:09:26 +0000 Subject: [PATCH 06/23] correct num inference steps --- src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py b/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py index 8ba1a4f637be..f116fb7894f0 100644 --- a/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +++ b/src/diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py @@ -267,7 +267,7 @@ def check_inputs( def __call__( self, prompt: Union[str, List[str]] = None, - num_inference_steps: int = 100, + num_inference_steps: int = 25, guidance_scale: float = 3.0, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, From 6d2e19f7466b70574209d3da4488e16610c4fac6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Nov 2023 06:13:20 +0100 Subject: [PATCH 07/23] [Examples] Allow downloading variant model files (#5531) * add variant * add variant * Apply suggestions from code review * reformat * fix: textual_inversion.py * fix: variant in model_info --------- Co-authored-by: sayakpaul --- examples/controlnet/train_controlnet.py | 20 +++++--- examples/controlnet/train_controlnet_sdxl.py | 29 ++++++++---- .../train_custom_diffusion.py | 18 ++++++-- examples/dreambooth/train_dreambooth.py | 20 +++++--- examples/dreambooth/train_dreambooth_flax.py | 10 +++- examples/dreambooth/train_dreambooth_lora.py | 16 +++++-- .../dreambooth/train_dreambooth_lora_sdxl.py | 46 +++++++++++++++---- .../train_instruct_pix2pix.py | 14 +++++- .../train_instruct_pix2pix_sdxl.py | 25 ++++++++-- .../t2i_adapter/train_t2i_adapter_sdxl.py | 24 ++++++++-- examples/text_to_image/train_text_to_image.py | 14 ++++-- .../text_to_image/train_text_to_image_flax.py | 6 +++ .../text_to_image/train_text_to_image_lora.py | 15 ++++-- .../train_text_to_image_lora_sdxl.py | 34 +++++++++++--- .../text_to_image/train_text_to_image_sdxl.py | 39 ++++++++++++---- .../textual_inversion/textual_inversion.py | 13 +++++- 16 files changed, 266 insertions(+), 77 deletions(-) diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 63b6767a6f8f..8dee7c33eac6 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -86,6 +86,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler controlnet=controlnet, safety_checker=None, revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) @@ -249,10 +250,13 @@ def parse_args(input_args=None): type=str, default=None, required=False, - help=( - "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" - " float32 precision." - ), + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", ) parser.add_argument( "--tokenizer_name", @@ -767,11 +771,13 @@ def main(args): # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder = text_encoder_cls.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) if args.controlnet_model_name_or_path: diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index b4fa96dae8ff..41a29c3945ab 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -74,6 +74,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step) unet=unet, controlnet=controlnet, revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) @@ -243,15 +244,18 @@ def parse_args(input_args=None): help="Path to pretrained controlnet model or model identifier from huggingface.co/models." " If not specified controlnet weights are initialized from unet.", ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) parser.add_argument( "--revision", type=str, default=None, required=False, - help=( - "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" - " float32 precision." - ), + help="Revision of pretrained model identifier from huggingface.co/models.", ) parser.add_argument( "--tokenizer_name", @@ -793,10 +797,16 @@ def main(args): # Load the tokenizers tokenizer_one = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, ) tokenizer_two = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, ) # import correct text encoder classes @@ -810,10 +820,10 @@ def main(args): # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder_one = text_encoder_cls_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant ) vae_path = ( args.pretrained_model_name_or_path @@ -824,9 +834,10 @@ def main(args): vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision, + variant=args.variant, ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) if args.controlnet_model_name_or_path: diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index d7f78841a81a..c619a46dd99d 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -332,6 +332,12 @@ def parse_args(input_args=None): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) parser.add_argument( "--tokenizer_name", type=str, @@ -740,6 +746,7 @@ def main(args): torch_dtype=torch_dtype, safety_checker=None, revision=args.revision, + variant=args.variant, ) pipeline.set_progress_bar_config(disable=True) @@ -801,11 +808,13 @@ def main(args): # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder = text_encoder_cls.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) # Adding a modifier token which is optimized #### @@ -1229,6 +1238,7 @@ def main(args): text_encoder=accelerator.unwrap_model(text_encoder), tokenizer=tokenizer, revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) @@ -1278,7 +1288,7 @@ def main(args): # Final inference # Load previous pipeline pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype + args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype ) pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) pipeline = pipeline.to(accelerator.device) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 92b57b728673..41854501144b 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -139,6 +139,7 @@ def log_validation( text_encoder=text_encoder, unet=accelerator.unwrap_model(unet), revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, **pipeline_args, ) @@ -239,10 +240,13 @@ def parse_args(input_args=None): type=str, default=None, required=False, - help=( - "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" - " float32 precision." - ), + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", ) parser.add_argument( "--tokenizer_name", @@ -859,6 +863,7 @@ def main(args): torch_dtype=torch_dtype, safety_checker=None, revision=args.revision, + variant=args.variant, ) pipeline.set_progress_bar_config(disable=True) @@ -912,18 +917,18 @@ def main(args): # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder = text_encoder_cls.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) if model_has_vae(args): vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant ) else: vae = None unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format @@ -1379,6 +1384,7 @@ def compute_text_embeddings(prompt): args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), revision=args.revision, + variant=args.variant, **pipeline_args, ) diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 5e8c385133e2..680c9dffdfcb 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -460,7 +460,10 @@ def collate_fn(examples): # Load models and create wrapper for stable diffusion text_encoder = FlaxCLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision + args.pretrained_model_name_or_path, + subfolder="text_encoder", + dtype=weight_dtype, + revision=args.revision, ) vae, vae_params = FlaxAutoencoderKL.from_pretrained( vae_arg, @@ -468,7 +471,10 @@ def collate_fn(examples): **vae_kwargs, ) unet, unet_params = FlaxUNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision + args.pretrained_model_name_or_path, + subfolder="unet", + dtype=weight_dtype, + revision=args.revision, ) # Optimization diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index b82dfa38c172..3ba775b543d8 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -183,6 +183,12 @@ def parse_args(input_args=None): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) parser.add_argument( "--tokenizer_name", type=str, @@ -750,6 +756,7 @@ def main(args): torch_dtype=torch_dtype, safety_checker=None, revision=args.revision, + variant=args.variant, ) pipeline.set_progress_bar_config(disable=True) @@ -803,11 +810,11 @@ def main(args): # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder = text_encoder_cls.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) try: vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant ) except OSError: # IF does not have a VAE so let's just set it to None @@ -815,7 +822,7 @@ def main(args): vae = None unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) # We only train the additional adapter LoRA layers @@ -1310,6 +1317,7 @@ def compute_text_embeddings(prompt): unet=accelerator.unwrap_model(unet), text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder), revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) @@ -1395,7 +1403,7 @@ def compute_text_embeddings(prompt): # Final inference # Load previous pipeline pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype + args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype ) # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index f4e7887c1c13..bbe8dab731e9 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -204,6 +204,12 @@ def parse_args(input_args=None): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) parser.add_argument( "--dataset_name", type=str, @@ -877,6 +883,7 @@ def main(args): args.pretrained_model_name_or_path, torch_dtype=torch_dtype, revision=args.revision, + variant=args.variant, ) pipeline.set_progress_bar_config(disable=True) @@ -915,10 +922,16 @@ def main(args): # Load the tokenizers tokenizer_one = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, ) tokenizer_two = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, ) # import correct text encoder classes @@ -932,10 +945,10 @@ def main(args): # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder_one = text_encoder_cls_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant ) vae_path = ( args.pretrained_model_name_or_path @@ -943,10 +956,13 @@ def main(args): else args.pretrained_vae_model_name_or_path ) vae = AutoencoderKL.from_pretrained( - vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) # We only train the additional adapter LoRA layers @@ -1571,10 +1587,16 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # create pipeline if not args.train_text_encoder: text_encoder_one = text_encoder_cls_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant, ) text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + args.pretrained_model_name_or_path, + subfolder="text_encoder_2", + revision=args.revision, + variant=args.variant, ) pipeline = StableDiffusionXLPipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -1583,6 +1605,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): text_encoder_2=accelerator.unwrap_model(text_encoder_two), unet=accelerator.unwrap_model(unet), revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) @@ -1660,10 +1683,15 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, ) # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index b9b1c9cc5b3b..2766e4c99086 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -78,6 +78,12 @@ def parse_args(): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) parser.add_argument( "--dataset_name", type=str, @@ -435,9 +441,11 @@ def main(): args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision ) text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision ) @@ -915,6 +923,7 @@ def collate_fn(examples): text_encoder=accelerator.unwrap_model(text_encoder), vae=accelerator.unwrap_model(vae), revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) @@ -966,6 +975,7 @@ def collate_fn(examples): vae=accelerator.unwrap_model(vae), unet=unet, revision=args.revision, + variant=args.variant, ) pipeline.save_pretrained(args.output_dir) diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py index 6b503cb29275..9b57b5eb08f9 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py @@ -118,6 +118,12 @@ def parse_args(): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) parser.add_argument( "--dataset_name", type=str, @@ -484,9 +490,10 @@ def main(): vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision, + variant=args.variant, ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) # InstructPix2Pix uses an additional image for conditioning. To accommodate that, @@ -695,10 +702,16 @@ def preprocess_images(examples): # Load scheduler, tokenizer and models. tokenizer_1 = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, ) tokenizer_2 = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, ) text_encoder_cls_1 = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) text_encoder_cls_2 = import_model_class_from_model_name_or_path( @@ -708,10 +721,10 @@ def preprocess_images(examples): # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder_1 = text_encoder_cls_1.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) text_encoder_2 = text_encoder_cls_2.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant ) # We ALWAYS pre-compute the additional condition embeddings needed for SDXL @@ -1109,6 +1122,7 @@ def collate_fn(examples): tokenizer_2=tokenizer_2, vae=vae, revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) @@ -1176,6 +1190,7 @@ def collate_fn(examples): vae=vae, unet=unet, revision=args.revision, + variant=args.variant, ) pipeline.save_pretrained(args.output_dir) diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py index d1c9113bbd9d..f8e58bdb80fa 100644 --- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py +++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py @@ -85,6 +85,7 @@ def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step): unet=unet, adapter=adapter, revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) @@ -262,6 +263,12 @@ def parse_args(input_args=None): " float32 precision." ), ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) parser.add_argument( "--tokenizer_name", type=str, @@ -812,10 +819,16 @@ def main(args): # Load the tokenizers tokenizer_one = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, ) tokenizer_two = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, ) # import correct text encoder classes @@ -829,10 +842,10 @@ def main(args): # Load scheduler and models noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder_one = text_encoder_cls_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant ) vae_path = ( args.pretrained_model_name_or_path @@ -843,9 +856,10 @@ def main(args): vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision, + variant=args.variant, ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) if args.adapter_model_name_or_path: diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 628a0c9d7d96..9a5482054939 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -148,6 +148,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight unet=accelerator.unwrap_model(unet), safety_checker=None, revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) @@ -209,6 +210,12 @@ def parse_args(): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) parser.add_argument( "--dataset_name", type=str, @@ -567,10 +574,10 @@ def deepspeed_zero_init_disabled_context_manager(): # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. with ContextManagers(deepspeed_zero_init_disabled_context_manager()): text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant ) unet = UNet2DConditionModel.from_pretrained( @@ -585,7 +592,7 @@ def deepspeed_zero_init_disabled_context_manager(): # Create EMA for the unet. if args.use_ema: ema_unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) @@ -1026,6 +1033,7 @@ def collate_fn(examples): vae=vae, unet=unet, revision=args.revision, + variant=args.variant, ) pipeline.save_pretrained(args.output_dir) diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index e62d03c730b1..aad29d1f565c 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -54,6 +54,12 @@ def parse_args(): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) parser.add_argument( "--dataset_name", type=str, diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index b7309196dec8..7d731c994bdd 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -130,6 +130,12 @@ def parse_args(): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) parser.add_argument( "--dataset_name", type=str, @@ -454,9 +460,11 @@ def main(): text_encoder = CLIPTextModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) # freeze parameters of models to save more memory unet.requires_grad_(False) @@ -881,6 +889,7 @@ def collate_fn(examples): args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) @@ -937,7 +946,7 @@ def collate_fn(examples): # Final inference # Load previous pipeline pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype + args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype ) pipeline = pipeline.to(accelerator.device) diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 96bfe9e16783..b69a85e4f463 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -180,6 +180,12 @@ def parse_args(input_args=None): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) parser.add_argument( "--dataset_name", type=str, @@ -570,10 +576,16 @@ def main(args): # Load the tokenizers tokenizer_one = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, ) tokenizer_two = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, ) # import correct text encoder classes @@ -587,10 +599,10 @@ def main(args): # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder_one = text_encoder_cls_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant ) vae_path = ( args.pretrained_model_name_or_path @@ -598,10 +610,13 @@ def main(args): else args.pretrained_vae_model_name_or_path ) vae = AutoencoderKL.from_pretrained( - vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) # We only train the additional adapter LoRA layers @@ -1176,6 +1191,7 @@ def compute_time_ids(original_size, crops_coords_top_left): text_encoder_2=accelerator.unwrap_model(text_encoder_two), unet=accelerator.unwrap_model(unet), revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) @@ -1241,7 +1257,11 @@ def compute_time_ids(original_size, crops_coords_top_left): # Final inference # Load previous pipeline pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, ) pipeline = pipeline.to(accelerator.device) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 041464e701cc..ee15e6f7def6 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -148,6 +148,12 @@ def parse_args(input_args=None): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) parser.add_argument( "--dataset_name", type=str, @@ -618,10 +624,16 @@ def main(args): # Load the tokenizers tokenizer_one = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, ) tokenizer_two = AutoTokenizer.from_pretrained( - args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, ) # import correct text encoder classes @@ -636,10 +648,10 @@ def main(args): noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") # Check for terminal SNR in combination with SNR Gamma text_encoder_one = text_encoder_cls_one.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) text_encoder_two = text_encoder_cls_two.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant ) vae_path = ( args.pretrained_model_name_or_path @@ -647,10 +659,13 @@ def main(args): else args.pretrained_vae_model_name_or_path ) vae = AutoencoderKL.from_pretrained( - vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) # Freeze vae and text encoders. @@ -677,7 +692,7 @@ def main(args): # Create EMA for the unet. if args.use_ema: ema_unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) @@ -1145,12 +1160,14 @@ def compute_time_ids(original_size, crops_coords_top_left): vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision, + variant=args.variant, ) pipeline = StableDiffusionXLPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, unet=accelerator.unwrap_model(unet), revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) if args.prediction_type is not None: @@ -1198,10 +1215,16 @@ def compute_time_ids(original_size, crops_coords_top_left): vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, unet=unet, vae=vae, revision=args.revision, torch_dtype=weight_dtype + args.pretrained_model_name_or_path, + unet=unet, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, ) if args.prediction_type is not None: scheduler_args = {"prediction_type": args.prediction_type} diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 8ce998aab1fb..8e932add92af 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -126,6 +126,7 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight vae=vae, safety_checker=None, revision=args.revision, + variant=args.variant, torch_dtype=weight_dtype, ) pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) @@ -206,6 +207,12 @@ def parse_args(): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) parser.add_argument( "--tokenizer_name", type=str, @@ -624,9 +631,11 @@ def main(): text_encoder = CLIPTextModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant + ) unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant ) # Add the placeholder token in tokenizer From 7d6f30e89ba3460dd26235c298c54d2ddb9d1590 Mon Sep 17 00:00:00 2001 From: Junsong Chen Date: Mon, 27 Nov 2023 02:05:35 -0600 Subject: [PATCH 08/23] [Fix: pixart-alpha] random 512px resolution bug (#5842) * [Fix: pixart-alpha] add ASPECT_RATIO_512_BIN in use_resolution_binning for random 512px image generation. * add slow test file for 512px generation without resolution binning * fix: slow tests for resolution binning. --------- Co-authored-by: jschen Co-authored-by: Patrick von Platen Co-authored-by: Sayak Paul --- .../pixart_alpha/pipeline_pixart_alpha.py | 41 +++++++++- tests/pipelines/pixart/test_pixart.py | 77 ++++++++++++++++--- 2 files changed, 105 insertions(+), 13 deletions(-) diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index ccb308f8780a..478ceb9b919f 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -97,6 +97,42 @@ "4.0": [2048.0, 512.0], } +ASPECT_RATIO_512_BIN = { + "0.25": [256.0, 1024.0], + "0.28": [256.0, 928.0], + "0.32": [288.0, 896.0], + "0.33": [288.0, 864.0], + "0.35": [288.0, 832.0], + "0.4": [320.0, 800.0], + "0.42": [320.0, 768.0], + "0.48": [352.0, 736.0], + "0.5": [352.0, 704.0], + "0.52": [352.0, 672.0], + "0.57": [384.0, 672.0], + "0.6": [384.0, 640.0], + "0.68": [416.0, 608.0], + "0.72": [416.0, 576.0], + "0.78": [448.0, 576.0], + "0.82": [448.0, 544.0], + "0.88": [480.0, 544.0], + "0.94": [480.0, 512.0], + "1.0": [512.0, 512.0], + "1.07": [512.0, 480.0], + "1.13": [544.0, 480.0], + "1.21": [544.0, 448.0], + "1.29": [576.0, 448.0], + "1.38": [576.0, 416.0], + "1.46": [608.0, 416.0], + "1.67": [640.0, 384.0], + "1.75": [672.0, 384.0], + "2.0": [704.0, 352.0], + "2.09": [736.0, 352.0], + "2.4": [768.0, 320.0], + "2.5": [800.0, 320.0], + "3.0": [864.0, 288.0], + "4.0": [1024.0, 256.0], +} + class PixArtAlphaPipeline(DiffusionPipeline): r""" @@ -691,8 +727,11 @@ def __call__( height = height or self.transformer.config.sample_size * self.vae_scale_factor width = width or self.transformer.config.sample_size * self.vae_scale_factor if use_resolution_binning: + aspect_ratio_bin = ( + ASPECT_RATIO_1024_BIN if self.transformer.config.sample_size == 128 else ASPECT_RATIO_512_BIN + ) orig_height, orig_width = height, width - height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN) + height, width = self.classify_height_width_bin(height, width, ratios=aspect_ratio_bin) self.check_inputs( prompt, diff --git a/tests/pipelines/pixart/test_pixart.py b/tests/pipelines/pixart/test_pixart.py index b2806a5c1c99..eced49e04261 100644 --- a/tests/pipelines/pixart/test_pixart.py +++ b/tests/pipelines/pixart/test_pixart.py @@ -320,6 +320,10 @@ def test_inference_batch_single_identical(self): @slow @require_torch_gpu class PixArtAlphaPipelineIntegrationTests(unittest.TestCase): + ckpt_id_1024 = "PixArt-alpha/PixArt-XL-2-1024-MS" + ckpt_id_512 = "PixArt-alpha/PixArt-XL-2-512x512" + prompt = "A small cactus with a happy face in the Sahara desert." + def tearDown(self): super().tearDown() gc.collect() @@ -328,10 +332,10 @@ def tearDown(self): def test_pixart_1024_fast(self): generator = torch.manual_seed(0) - pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) + pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_1024, torch_dtype=torch.float16) pipe.enable_model_cpu_offload() - prompt = "A small cactus with a happy face in the Sahara desert." + prompt = self.prompt image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images @@ -345,10 +349,10 @@ def test_pixart_1024_fast(self): def test_pixart_512_fast(self): generator = torch.manual_seed(0) - pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-512x512", torch_dtype=torch.float16) + pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.float16) pipe.enable_model_cpu_offload() - prompt = "A small cactus with a happy face in the Sahara desert." + prompt = self.prompt image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images @@ -362,9 +366,9 @@ def test_pixart_512_fast(self): def test_pixart_1024(self): generator = torch.manual_seed(0) - pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) + pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_1024, torch_dtype=torch.float16) pipe.enable_model_cpu_offload() - prompt = "A small cactus with a happy face in the Sahara desert." + prompt = self.prompt image = pipe(prompt, generator=generator, output_type="np").images @@ -378,10 +382,10 @@ def test_pixart_1024(self): def test_pixart_512(self): generator = torch.manual_seed(0) - pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-512x512", torch_dtype=torch.float16) + pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.float16) pipe.enable_model_cpu_offload() - prompt = "A small cactus with a happy face in the Sahara desert." + prompt = self.prompt image = pipe(prompt, generator=generator, output_type="np").images @@ -395,17 +399,66 @@ def test_pixart_512(self): def test_pixart_1024_without_resolution_binning(self): generator = torch.manual_seed(0) - pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) + pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_1024, torch_dtype=torch.float16) pipe.enable_model_cpu_offload() - prompt = "A small cactus with a happy face in the Sahara desert." + prompt = self.prompt + height, width = 1024, 768 + num_inference_steps = 10 + + image = pipe( + prompt, + height=height, + width=width, + generator=generator, + num_inference_steps=num_inference_steps, + output_type="np", + ).images + image_slice = image[0, -3:, -3:, -1] + + generator = torch.manual_seed(0) + no_res_bin_image = pipe( + prompt, + height=height, + width=width, + generator=generator, + num_inference_steps=num_inference_steps, + output_type="np", + use_resolution_binning=False, + ).images + no_res_bin_image_slice = no_res_bin_image[0, -3:, -3:, -1] - image = pipe(prompt, generator=generator, num_inference_steps=5, output_type="np").images + assert not np.allclose(image_slice, no_res_bin_image_slice, atol=1e-4, rtol=1e-4) + + def test_pixart_512_without_resolution_binning(self): + generator = torch.manual_seed(0) + + pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.float16) + pipe.enable_model_cpu_offload() + + prompt = self.prompt + height, width = 512, 768 + num_inference_steps = 10 + + image = pipe( + prompt, + height=height, + width=width, + generator=generator, + num_inference_steps=num_inference_steps, + output_type="np", + ).images image_slice = image[0, -3:, -3:, -1] generator = torch.manual_seed(0) no_res_bin_image = pipe( - prompt, generator=generator, num_inference_steps=5, output_type="np", use_resolution_binning=False + prompt, + height=height, + width=width, + generator=generator, + num_inference_steps=num_inference_steps, + output_type="np", + use_resolution_binning=False, ).images no_res_bin_image_slice = no_res_bin_image[0, -3:, -3:, -1] From 3f7c3511dcc95e5bb9fd53399dfc4eb655e1d6fd Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 27 Nov 2023 16:21:12 +0530 Subject: [PATCH 09/23] [Core] add support for gradient checkpointing in transformer_2d (#5943) add support for gradient checkpointing in transformer_2d --- src/diffusers/models/transformer_2d.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index 24abf54d6da7..3aecc43f0f5b 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -20,7 +20,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..models.embeddings import ImagePositionalEmbeddings -from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate +from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version from .attention import BasicTransformerBlock from .embeddings import CaptionProjection, PatchEmbed from .lora import LoRACompatibleConv, LoRACompatibleLinear @@ -70,6 +70,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin): Configure if the `TransformerBlocks` attention should contain a bias parameter. """ + _supports_gradient_checkpointing = True + @register_to_config def __init__( self, @@ -237,6 +239,10 @@ def __init__( self.gradient_checkpointing = False + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + def forward( self, hidden_states: torch.Tensor, @@ -360,8 +366,19 @@ def forward( for block in self.transformer_blocks: if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( - block, + create_custom_forward(block), hidden_states, attention_mask, encoder_hidden_states, @@ -369,7 +386,7 @@ def forward( timestep, cross_attention_kwargs, class_labels, - use_reentrant=False, + **ckpt_kwargs, ) else: hidden_states = block( From 9c357bda3f0982889d59f7053de38de0dc8038ef Mon Sep 17 00:00:00 2001 From: Aryan V S Date: Mon, 27 Nov 2023 17:03:02 +0530 Subject: [PATCH 10/23] Deprecate KarrasVeScheduler and ScoreSdeVpScheduler (#5269) * deprecated: KarrasVeScheduler, ScoreSdeVpScheduler * delete tests relevant to deprecated schedulers * chore: run make style * fix: import error caused due to incorrect _import_structure after deprecation * fix: ScoreSdeVpScheduler was not importable from diffusers * remove import added by assumption * Update src/diffusers/schedulers/__init__.py as suggested by @patrickvonplaten Co-authored-by: Patrick von Platen * make it a part deprecated * Apply suggestions from code review Co-authored-by: Patrick von Platen * Fix * fix * fix doc * fix doc....again....... * remove karras_ve test folder Co-Authored-By: YiYi Xu --------- Co-authored-by: Patrick von Platen Co-authored-by: Sayak Paul Co-authored-by: YiYi Xu Co-authored-by: yiyixuxu --- docs/source/en/api/schedulers/score_sde_vp.md | 2 +- .../en/api/schedulers/stochastic_karras_ve.md | 2 +- src/diffusers/schedulers/__init__.py | 6 +- .../schedulers/deprecated/__init__.py | 50 +++++++++++ .../{ => deprecated}/scheduling_karras_ve.py | 8 +- .../{ => deprecated}/scheduling_sde_vp.py | 6 +- tests/pipelines/karras_ve/__init__.py | 0 tests/pipelines/karras_ve/test_karras_ve.py | 86 ------------------- 8 files changed, 61 insertions(+), 99 deletions(-) create mode 100644 src/diffusers/schedulers/deprecated/__init__.py rename src/diffusers/schedulers/{ => deprecated}/scheduling_karras_ve.py (98%) rename src/diffusers/schedulers/{ => deprecated}/scheduling_sde_vp.py (96%) delete mode 100644 tests/pipelines/karras_ve/__init__.py delete mode 100644 tests/pipelines/karras_ve/test_karras_ve.py diff --git a/docs/source/en/api/schedulers/score_sde_vp.md b/docs/source/en/api/schedulers/score_sde_vp.md index 204cba877722..85da7e8ed539 100644 --- a/docs/source/en/api/schedulers/score_sde_vp.md +++ b/docs/source/en/api/schedulers/score_sde_vp.md @@ -25,4 +25,4 @@ The abstract from the paper is: ## ScoreSdeVpScheduler -[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler +[[autodoc]] schedulers.deprecated.scheduling_sde_vp.ScoreSdeVpScheduler diff --git a/docs/source/en/api/schedulers/stochastic_karras_ve.md b/docs/source/en/api/schedulers/stochastic_karras_ve.md index eb954d7e5e7b..1bfe4e52e514 100644 --- a/docs/source/en/api/schedulers/stochastic_karras_ve.md +++ b/docs/source/en/api/schedulers/stochastic_karras_ve.md @@ -18,4 +18,4 @@ specific language governing permissions and limitations under the License. [[autodoc]] KarrasVeScheduler ## KarrasVeOutput -[[autodoc]] schedulers.scheduling_karras_ve.KarrasVeOutput +[[autodoc]] schedulers.deprecated.scheduling_karras_ve.KarrasVeOutput \ No newline at end of file diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 5e5102e589d4..40c435dd5637 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -38,6 +38,7 @@ _dummy_modules.update(get_objects_from_module(dummy_pt_objects)) else: + _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"] _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] _import_structure["scheduling_ddim"] = ["DDIMScheduler"] @@ -56,12 +57,10 @@ _import_structure["scheduling_ipndm"] = ["IPNDMScheduler"] _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"] _import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"] - _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"] _import_structure["scheduling_lcm"] = ["LCMScheduler"] _import_structure["scheduling_pndm"] = ["PNDMScheduler"] _import_structure["scheduling_repaint"] = ["RePaintScheduler"] _import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"] - _import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"] _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"] _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"] _import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"] @@ -129,6 +128,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_pt_objects import * # noqa F403 else: + from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler @@ -147,12 +147,10 @@ from .scheduling_ipndm import IPNDMScheduler from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler - from .scheduling_karras_ve import KarrasVeScheduler from .scheduling_lcm import LCMScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_repaint import RePaintScheduler from .scheduling_sde_ve import ScoreSdeVeScheduler - from .scheduling_sde_vp import ScoreSdeVpScheduler from .scheduling_unclip import UnCLIPScheduler from .scheduling_unipc_multistep import UniPCMultistepScheduler from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin diff --git a/src/diffusers/schedulers/deprecated/__init__.py b/src/diffusers/schedulers/deprecated/__init__.py new file mode 100644 index 000000000000..786707f45206 --- /dev/null +++ b/src/diffusers/schedulers/deprecated/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_pt_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_pt_objects)) +else: + _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"] + _import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ..utils.dummy_pt_objects import * # noqa F403 + else: + from .scheduling_karras_ve import KarrasVeScheduler + from .scheduling_sde_vp import ScoreSdeVpScheduler + + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py similarity index 98% rename from src/diffusers/schedulers/scheduling_karras_ve.py rename to src/diffusers/schedulers/deprecated/scheduling_karras_ve.py index 462169b633de..97466ecf8153 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py @@ -19,10 +19,10 @@ import numpy as np import torch -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput -from ..utils.torch_utils import randn_tensor -from .scheduling_utils import SchedulerMixin +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ...utils.torch_utils import randn_tensor +from ..scheduling_utils import SchedulerMixin @dataclass diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/deprecated/scheduling_sde_vp.py similarity index 96% rename from src/diffusers/schedulers/scheduling_sde_vp.py rename to src/diffusers/schedulers/deprecated/scheduling_sde_vp.py index 177dcbbfaba9..2d0e11378cca 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/deprecated/scheduling_sde_vp.py @@ -19,9 +19,9 @@ import torch -from ..configuration_utils import ConfigMixin, register_to_config -from ..utils.torch_utils import randn_tensor -from .scheduling_utils import SchedulerMixin +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.torch_utils import randn_tensor +from ..scheduling_utils import SchedulerMixin class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): diff --git a/tests/pipelines/karras_ve/__init__.py b/tests/pipelines/karras_ve/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/tests/pipelines/karras_ve/test_karras_ve.py b/tests/pipelines/karras_ve/test_karras_ve.py deleted file mode 100644 index 228d65e508c9..000000000000 --- a/tests/pipelines/karras_ve/test_karras_ve.py +++ /dev/null @@ -1,86 +0,0 @@ -# coding=utf-8 -# Copyright 2023 HuggingFace Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import numpy as np -import torch - -from diffusers import KarrasVePipeline, KarrasVeScheduler, UNet2DModel -from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch, torch_device - - -enable_full_determinism() - - -class KarrasVePipelineFastTests(unittest.TestCase): - @property - def dummy_uncond_unet(self): - torch.manual_seed(0) - model = UNet2DModel( - block_out_channels=(32, 64), - layers_per_block=2, - sample_size=32, - in_channels=3, - out_channels=3, - down_block_types=("DownBlock2D", "AttnDownBlock2D"), - up_block_types=("AttnUpBlock2D", "UpBlock2D"), - ) - return model - - def test_inference(self): - unet = self.dummy_uncond_unet - scheduler = KarrasVeScheduler() - - pipe = KarrasVePipeline(unet=unet, scheduler=scheduler) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - generator = torch.manual_seed(0) - image = pipe(num_inference_steps=2, generator=generator, output_type="numpy").images - - generator = torch.manual_seed(0) - image_from_tuple = pipe(num_inference_steps=2, generator=generator, output_type="numpy", return_dict=False)[0] - - image_slice = image[0, -3:, -3:, -1] - image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]) - - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 - - -@nightly -@require_torch -class KarrasVePipelineIntegrationTests(unittest.TestCase): - def test_inference(self): - model_id = "google/ncsnpp-celebahq-256" - model = UNet2DModel.from_pretrained(model_id) - scheduler = KarrasVeScheduler() - - pipe = KarrasVePipeline(unet=model, scheduler=scheduler) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - - generator = torch.manual_seed(0) - image = pipe(num_inference_steps=20, generator=generator, output_type="numpy").images - - image_slice = image[0, -3:, -3:, -1] - assert image.shape == (1, 256, 256, 3) - expected_slice = np.array([0.578, 0.5811, 0.5924, 0.5809, 0.587, 0.5886, 0.5861, 0.5802, 0.586]) - - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 From 67d070749ae393a234470b6ef653821bb4f02cc6 Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Mon, 27 Nov 2023 03:39:14 -0800 Subject: [PATCH 11/23] Add Custom Timesteps Support to LCMScheduler and Supported Pipelines (#5874) * Add custom timesteps support to LCMScheduler. * Add custom timesteps support to StableDiffusionPipeline. * Add custom timesteps support to StableDiffusionXLPipeline. * Add custom timesteps support to remaining Stable Diffusion pipelines which support LCMScheduler (img2img, inpaint). * Add custom timesteps support to remaining Stable Diffusion XL pipelines which support LCMScheduler (img2img, inpaint). * Add custom timesteps support to StableDiffusionControlNetPipeline. * Add custom timesteps support to T21 Stable Diffusion (XL) Adapters. * Clean up Stable Diffusion inpaint tests. * Manually add support for custom timesteps to AltDiffusion pipelines since make fix-copies doesn't appear to work correctly (it deletes the whole pipeline). * make style * Refactor pipeline timestep handling into the retrieve_timesteps function. --- .../alt_diffusion/pipeline_alt_diffusion.py | 53 ++++++- .../pipeline_alt_diffusion_img2img.py | 52 ++++++- .../controlnet/pipeline_controlnet.py | 53 ++++++- .../pipeline_latent_consistency_img2img.py | 60 +++++++- .../pipeline_latent_consistency_text2img.py | 55 ++++++- .../pipeline_stable_diffusion.py | 52 ++++++- .../pipeline_stable_diffusion_img2img.py | 52 ++++++- .../pipeline_stable_diffusion_inpaint.py | 52 ++++++- .../pipeline_stable_diffusion_xl.py | 54 ++++++- .../pipeline_stable_diffusion_xl_img2img.py | 52 ++++++- .../pipeline_stable_diffusion_xl_inpaint.py | 52 ++++++- .../pipeline_stable_diffusion_adapter.py | 53 ++++++- .../pipeline_stable_diffusion_xl_adapter.py | 54 ++++++- src/diffusers/schedulers/scheduling_lcm.py | 143 ++++++++++++++---- tests/pipelines/controlnet/test_controlnet.py | 24 +++ .../test_latent_consistency_models.py | 19 +++ .../test_latent_consistency_models_img2img.py | 19 +++ .../stable_diffusion/test_stable_diffusion.py | 22 +++ .../test_stable_diffusion_adapter.py | 22 +++ .../test_stable_diffusion_img2img.py | 19 +++ .../test_stable_diffusion_inpaint.py | 38 +++++ .../test_stable_diffusion_xl.py | 19 +++ .../test_stable_diffusion_xl_adapter.py | 47 ++++++ .../test_stable_diffusion_xl_img2img.py | 20 +++ .../test_stable_diffusion_xl_inpaint.py | 20 +++ tests/schedulers/test_scheduler_lcm.py | 56 +++++++ 26 files changed, 1109 insertions(+), 53 deletions(-) diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 843e3b8b9410..b5c7aee4b4de 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -73,6 +73,51 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker class AltDiffusionPipeline( DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin @@ -662,6 +707,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + timesteps: List[int] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -693,6 +739,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -830,8 +880,7 @@ def __call__( image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index b196ac4d3f69..9b5eb1b4c66d 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -109,6 +109,51 @@ def preprocess(image): return image +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker class AltDiffusionImg2ImgPipeline( DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin @@ -714,6 +759,7 @@ def __call__( image: PipelineImageInput = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, + timesteps: List[int] = None, guidance_scale: Optional[float] = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -751,6 +797,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter is modulated by `strength`. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -873,7 +923,7 @@ def __call__( image = self.image_processor.preprocess(image) # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 41e5e75f68e5..1e19678b221d 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -91,6 +91,51 @@ """ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class StableDiffusionControlNetPipeline( DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin ): @@ -812,6 +857,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + timesteps: List[int] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -854,6 +900,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -1059,8 +1109,7 @@ def __call__( assert False # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) self._num_timesteps = len(timesteps) # 6. Prepare latent variables diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index 2d5de69d6e88..0e7bd6e72281 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -53,6 +53,51 @@ def retrieve_latents(encoder_output, generator): raise AttributeError("Could not access latents of provided encoder_output") +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -592,6 +637,7 @@ def __call__( num_inference_steps: int = 4, strength: float = 0.8, original_inference_steps: int = None, + timesteps: List[int] = None, guidance_scale: float = 8.5, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -623,6 +669,10 @@ def __call__( we will draw `num_inference_steps` evenly spaced timesteps from as our final timestep schedule, following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the scheduler's `original_inference_steps` attribute. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps on the original LCM training/distillation timestep schedule are used. Must be in descending + order. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -728,10 +778,14 @@ def __call__( image = self.image_processor.preprocess(image) # 5. Prepare timesteps - self.scheduler.set_timesteps( - num_inference_steps, device, original_inference_steps=original_inference_steps, strength=strength + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + original_inference_steps=original_inference_steps, + strength=strength, ) - timesteps = self.scheduler.timesteps # 6. Prepare latent variables original_inference_steps = ( diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py index c32538625f01..c8f1d647c15b 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py @@ -61,6 +61,51 @@ """ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class LatentConsistencyModelPipeline( DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin ): @@ -530,6 +575,7 @@ def __call__( width: Optional[int] = None, num_inference_steps: int = 4, original_inference_steps: int = None, + timesteps: List[int] = None, guidance_scale: float = 8.5, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -561,6 +607,10 @@ def __call__( we will draw `num_inference_steps` evenly spaced timesteps from as our final timestep schedule, following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the scheduler's `original_inference_steps` attribute. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps on the original LCM training/distillation timestep schedule are used. Must be in descending + order. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -668,8 +718,9 @@ def __call__( ) # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device, original_inference_steps=original_inference_steps) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, original_inference_steps=original_inference_steps + ) # 5. Prepare latent variable num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index a05abe00f2b1..bf43c043490b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -70,6 +70,50 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class StableDiffusionPipeline( DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin ): @@ -659,6 +703,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + timesteps: List[int] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -690,6 +735,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -827,8 +876,7 @@ def __call__( image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 029cd2b04839..1bec0807a2e0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -105,6 +105,51 @@ def preprocess(image): return image +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class StableDiffusionImg2ImgPipeline( DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin ): @@ -718,6 +763,7 @@ def __call__( image: PipelineImageInput = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, + timesteps: List[int] = None, guidance_scale: Optional[float] = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -755,6 +801,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter is modulated by `strength`. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -877,7 +927,7 @@ def __call__( image = self.image_processor.preprocess(image) # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 09e50c60a807..251dfb5676c1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -169,6 +169,51 @@ def retrieve_latents(encoder_output, generator): raise AttributeError("Could not access latents of provided encoder_output") +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class StableDiffusionInpaintPipeline( DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin ): @@ -846,6 +891,7 @@ def __call__( width: Optional[int] = None, strength: float = 1.0, num_inference_steps: int = 50, + timesteps: List[int] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -896,6 +942,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter is modulated by `strength`. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. @@ -1054,7 +1104,7 @@ def __call__( image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps=num_inference_steps, strength=strength, device=device ) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index e32791693012..40c981a46d48 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -100,6 +100,51 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class StableDiffusionXLPipeline( DiffusionPipeline, FromSingleFileMixin, @@ -742,6 +787,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + timesteps: List[int] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -793,6 +839,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will @@ -984,9 +1034,7 @@ def __call__( ) # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index d40a037e67fe..dc8b95bf99a6 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -114,6 +114,51 @@ def retrieve_latents(encoder_output, generator): raise AttributeError("Could not access latents of provided encoder_output") +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class StableDiffusionXLImg2ImgPipeline( DiffusionPipeline, TextualInversionLoaderMixin, @@ -877,6 +922,7 @@ def __call__( image: PipelineImageInput = None, strength: float = 0.3, num_inference_steps: int = 50, + timesteps: List[int] = None, denoising_start: Optional[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, @@ -930,6 +976,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. denoising_start (`float`, *optional*): When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and @@ -1137,7 +1187,7 @@ def __call__( def denoising_value_valid(dnv): return isinstance(self.denoising_end, float) and 0 < dnv < 1 - self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, strength, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 3a9d068d60f3..e49ec0d607d6 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -259,6 +259,51 @@ def retrieve_latents(encoder_output, generator): raise AttributeError("Could not access latents of provided encoder_output") +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class StableDiffusionXLInpaintPipeline( DiffusionPipeline, TextualInversionLoaderMixin, @@ -1101,6 +1146,7 @@ def __call__( width: Optional[int] = None, strength: float = 0.9999, num_inference_steps: int = 50, + timesteps: List[int] = None, denoising_start: Optional[float] = None, denoising_end: Optional[float] = None, guidance_scale: float = 7.5, @@ -1171,6 +1217,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. denoising_start (`float`, *optional*): When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and @@ -1376,7 +1426,7 @@ def __call__( def denoising_value_valid(dnv): return isinstance(self.denoising_end, float) and 0 < dnv < 1 - self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, strength, diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index 7418e7630f52..a0a17e8cacec 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -118,6 +118,51 @@ def _preprocess_adapter_image(image, height, width): return image +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class StableDiffusionAdapterPipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter @@ -660,6 +705,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + timesteps: List[int] = None, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, @@ -694,6 +740,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen @@ -803,8 +853,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index 6e3f6a56c100..b07c98fef679 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -123,6 +123,51 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + class StableDiffusionXLAdapterPipeline( DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin ): @@ -721,6 +766,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, + timesteps: List[int] = None, denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -774,6 +820,10 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. denoising_end (`float`, *optional*): When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be completed before it is intentionally prematurely terminated. As a result, the returned sample will @@ -957,9 +1007,7 @@ def __call__( ) # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) # 5. Prepare latent variables num_channels_latents = self.unet.config.in_channels diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py index c21b556c6ca4..8dd39f261540 100644 --- a/src/diffusers/schedulers/scheduling_lcm.py +++ b/src/diffusers/schedulers/scheduling_lcm.py @@ -247,6 +247,7 @@ def __init__( # setable values self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + self.custom_timesteps = False self._step_index = None @@ -324,17 +325,19 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: def set_timesteps( self, - num_inference_steps: int, + num_inference_steps: Optional[int] = None, device: Union[str, torch.device] = None, original_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, strength: int = 1.0, ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). Args: - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. + num_inference_steps (`int`, *optional*): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. original_inference_steps (`int`, *optional*): @@ -342,16 +345,19 @@ def set_timesteps( schedule (which is different from the standard `diffusers` implementation). We will then take `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep + schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`. """ + # 0. Check inputs + if num_inference_steps is None and timesteps is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.") - if num_inference_steps > self.config.num_train_timesteps: - raise ValueError( - f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" - f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" - f" maximal {self.config.num_train_timesteps} timesteps." - ) + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") - self.num_inference_steps = num_inference_steps + # 1. Calculate the LCM original training/distillation timestep schedule. original_steps = ( original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps ) @@ -363,32 +369,95 @@ def set_timesteps( f" maximal {self.config.num_train_timesteps} timesteps." ) - if num_inference_steps > original_steps: - raise ValueError( - f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:" - f" {original_steps} because the final timestep schedule will be a subset of the" - f" `original_inference_steps`-sized initial timestep schedule." - ) - # LCM Timesteps Setting # The skipping step parameter k from the paper. k = self.config.num_train_timesteps // original_steps # LCM Training/Distillation Steps Schedule # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts). lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1 - skipping_step = len(lcm_origin_timesteps) // num_inference_steps - if skipping_step < 1: - raise ValueError( - f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}." - ) - - # LCM Inference Steps Schedule - lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy() - # Select (approximately) evenly spaced indices from lcm_origin_timesteps. - inference_indices = np.linspace(0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False) - inference_indices = np.floor(inference_indices).astype(np.int64) - timesteps = lcm_origin_timesteps[inference_indices] + # 2. Calculate the LCM inference timestep schedule. + if timesteps is not None: + # 2.1 Handle custom timestep schedules. + train_timesteps = set(lcm_origin_timesteps) + non_train_timesteps = [] + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError("`custom_timesteps` must be in descending order.") + + if timesteps[i] not in train_timesteps: + non_train_timesteps.append(timesteps[i]) + + if timesteps[0] >= self.config.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps}." + ) + + # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1 + if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1: + logger.warning( + f"The first timestep on the custom timestep schedule is {timesteps[0]}, not" + f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get" + f" unexpected results when using this timestep schedule." + ) + + # Raise warning if custom timestep schedule contains timesteps not on original timestep schedule + if non_train_timesteps: + logger.warning( + f"The custom timestep schedule contains the following timesteps which are not on the original" + f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results" + f" when using this timestep schedule." + ) + + # Raise warning if custom timestep schedule is longer than original_steps + if len(timesteps) > original_steps: + logger.warning( + f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the" + f" the length of the timestep schedule used for training: {original_steps}. You may get some" + f" unexpected results when using this timestep schedule." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + self.num_inference_steps = len(timesteps) + self.custom_timesteps = True + + # Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps) + init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps) + t_start = max(self.num_inference_steps - init_timestep, 0) + timesteps = timesteps[t_start * self.order :] + # TODO: also reset self.num_inference_steps? + else: + # 2.2 Create the "standard" LCM inference timestep schedule. + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + skipping_step = len(lcm_origin_timesteps) // num_inference_steps + + if skipping_step < 1: + raise ValueError( + f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}." + ) + + self.num_inference_steps = num_inference_steps + + if num_inference_steps > original_steps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:" + f" {original_steps} because the final timestep schedule will be a subset of the" + f" `original_inference_steps`-sized initial timestep schedule." + ) + + # LCM Inference Steps Schedule + lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy() + # Select (approximately) evenly spaced indices from lcm_origin_timesteps. + inference_indices = np.linspace(0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False) + inference_indices = np.floor(inference_indices).astype(np.int64) + timesteps = lcm_origin_timesteps[inference_indices] self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long) @@ -545,3 +614,19 @@ def get_velocity( def __len__(self): return self.config.num_train_timesteps + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep + def previous_timestep(self, timestep): + if self.custom_timesteps: + index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] + if index == self.timesteps.shape[0] - 1: + prev_t = torch.tensor(-1) + else: + prev_t = self.timesteps[index + 1] + else: + num_inference_steps = ( + self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps + ) + prev_t = timestep - self.config.num_train_timesteps // num_inference_steps + + return prev_t diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index 1cf52bfeebe2..ce8693343043 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -246,6 +246,30 @@ def test_controlnet_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_controlnet_lcm_custom_timesteps(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionControlNetPipeline(**components) + sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["num_inference_steps"] + inputs["timesteps"] = [999, 499] + output = sd_pipe(**inputs) + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array( + [0.52700454, 0.3930534, 0.25509018, 0.7132304, 0.53696585, 0.46568912, 0.7095368, 0.7059624, 0.4744786] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + class StableDiffusionMultiControlNetPipelineFastTests( PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py index d68ef42a25c6..174d9b6de9f8 100644 --- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py +++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py @@ -140,6 +140,25 @@ def test_lcm_multistep(self): expected_slice = np.array([0.1403, 0.5072, 0.5316, 0.1202, 0.3865, 0.4211, 0.5363, 0.3557, 0.3645]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_lcm_custom_timesteps(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components() + pipe = LatentConsistencyModelPipeline(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["num_inference_steps"] + inputs["timesteps"] = [999, 499] + output = pipe(**inputs) + image = output.images + assert image.shape == (1, 64, 64, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.1403, 0.5072, 0.5316, 0.1202, 0.3865, 0.4211, 0.5363, 0.3557, 0.3645]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=5e-4) diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py index 53702925534d..f9410ffe640a 100644 --- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py +++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py @@ -153,6 +153,25 @@ def test_lcm_multistep(self): expected_slice = np.array([0.4150, 0.3719, 0.2479, 0.6333, 0.6024, 0.3778, 0.5036, 0.5420, 0.4678]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_lcm_custom_timesteps(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["num_inference_steps"] + inputs["timesteps"] = [999, 499] + output = pipe(**inputs) + image = output.images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.3994, 0.3471, 0.2540, 0.7030, 0.6193, 0.3645, 0.5777, 0.5850, 0.4965]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=5e-4) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 15c1c4fe6671..28d0d07e6948 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -220,6 +220,28 @@ def test_stable_diffusion_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_lcm_custom_timesteps(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["num_inference_steps"] + inputs["timesteps"] = [999, 499] + output = sd_pipe(**inputs) + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.3454, 0.5349, 0.5185, 0.2808, 0.4509, 0.4612, 0.4655, 0.3601, 0.4315]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_prompt_embeds(self): components = self.get_dummy_components() sd_pipe = StableDiffusionPipeline(**components) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py index 2252c8ef8e99..a5e8649f060f 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py @@ -314,6 +314,28 @@ def test_adapter_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_adapter_lcm_custom_timesteps(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionAdapterPipeline(**components) + sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["num_inference_steps"] + inputs["timesteps"] = [999, 499] + output = sd_pipe(**inputs) + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4535, 0.5493, 0.4359, 0.5452, 0.6086, 0.4441, 0.5544, 0.501, 0.4859]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + class StableDiffusionFullAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase): def get_dummy_components(self, time_cond_proj_dim=None): diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index 1a482b38e2ee..fb56d868f1cc 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -207,6 +207,25 @@ def test_stable_diffusion_img2img_default_case_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_stable_diffusion_img2img_default_case_lcm_custom_timesteps(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionImg2ImgPipeline(**components) + sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["num_inference_steps"] + inputs["timesteps"] = [999, 499] + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.5709, 0.4614, 0.4587, 0.5978, 0.5298, 0.6910, 0.6240, 0.5212, 0.5454]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_stable_diffusion_img2img_negative_prompt(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index cbe4fb2a0ddf..a69edb869641 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -226,6 +226,25 @@ def test_stable_diffusion_inpaint_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_inpaint_lcm_custom_timesteps(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionInpaintPipeline(**components) + sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["num_inference_steps"] + inputs["timesteps"] = [999, 499] + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4931, 0.5988, 0.4569, 0.5556, 0.6650, 0.5087, 0.5966, 0.5358, 0.5269]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_inpaint_image_tensor(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() @@ -420,6 +439,25 @@ def test_stable_diffusion_inpaint_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_inpaint_lcm_custom_timesteps(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionInpaintPipeline(**components) + sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["num_inference_steps"] + inputs["timesteps"] = [999, 499] + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.6240, 0.5355, 0.5649, 0.5378, 0.5374, 0.6242, 0.5132, 0.5347, 0.5396]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_inpaint_2_images(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 8957ebbef5ab..59f0c0151d3a 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -183,6 +183,25 @@ def test_stable_diffusion_xl_euler_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_xl_euler_lcm_custom_timesteps(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionXLPipeline(**components) + sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["num_inference_steps"] + inputs["timesteps"] = [999, 499] + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4917, 0.6555, 0.4348, 0.5219, 0.7324, 0.4855, 0.5168, 0.5447, 0.5156]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_xl_prompt_embeds(self): components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**components) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py index daf46000a1e0..f63ee8be1dd0 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py @@ -389,6 +389,28 @@ def test_adapter_sdxl_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_adapter_sdxl_lcm_custom_timesteps(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionXLAdapterPipeline(**components) + sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["num_inference_steps"] + inputs["timesteps"] = [999, 499] + output = sd_pipe(**inputs) + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5425, 0.5385, 0.4964, 0.5045, 0.6149, 0.4974, 0.5469, 0.5332, 0.5426]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + class StableDiffusionXLMultiAdapterPipelineFastTests( StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase @@ -614,6 +636,31 @@ def test_adapter_sdxl_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_adapter_sdxl_lcm_custom_timesteps(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionXLAdapterPipeline(**components) + sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["num_inference_steps"] + inputs["timesteps"] = [999, 499] + output = sd_pipe(**inputs) + image = output.images + + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448]) + + debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()] + print(",".join(debug)) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + @slow @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index 444f12ecfa9d..7cad3fff0a47 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -227,6 +227,26 @@ def test_stable_diffusion_xl_img2img_euler_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_xl_img2img_euler_lcm_custom_timesteps(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionXLImg2ImgPipeline(**components) + sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.config) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["num_inference_steps"] + inputs["timesteps"] = [999, 499] + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + + expected_slice = np.array([0.5604, 0.4352, 0.4717, 0.5844, 0.5101, 0.6704, 0.6290, 0.5460, 0.5286]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_attention_slicing_forward_pass(self): super().test_attention_slicing_forward_pass(expected_max_diff=3e-3) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py index 7f7a0d81e5a2..4a2798b3edf4 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py @@ -264,6 +264,26 @@ def test_stable_diffusion_xl_inpaint_euler_lcm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_xl_inpaint_euler_lcm_custom_timesteps(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components(time_cond_proj_dim=256) + sd_pipe = StableDiffusionXLInpaintPipeline(**components) + sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.config) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + del inputs["num_inference_steps"] + inputs["timesteps"] = [999, 499] + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array([0.6611, 0.5569, 0.5531, 0.5471, 0.5918, 0.6393, 0.5074, 0.5468, 0.5185]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_attention_slicing_forward_pass(self): super().test_attention_slicing_forward_pass(expected_max_diff=3e-3) diff --git a/tests/schedulers/test_scheduler_lcm.py b/tests/schedulers/test_scheduler_lcm.py index 014cdca90479..c2c6530faa11 100644 --- a/tests/schedulers/test_scheduler_lcm.py +++ b/tests/schedulers/test_scheduler_lcm.py @@ -242,3 +242,59 @@ def test_full_loop_multistep(self): # TODO: get expected sum and mean assert abs(result_sum.item() - 197.7616) < 1e-3 assert abs(result_mean.item() - 0.2575) < 1e-3 + + def test_custom_timesteps(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + timesteps = [100, 87, 50, 1, 0] + + scheduler.set_timesteps(timesteps=timesteps) + + scheduler_timesteps = scheduler.timesteps + + for i, timestep in enumerate(scheduler_timesteps): + if i == len(timesteps) - 1: + expected_prev_t = -1 + else: + expected_prev_t = timesteps[i + 1] + + prev_t = scheduler.previous_timestep(timestep) + prev_t = prev_t.item() + + self.assertEqual(prev_t, expected_prev_t) + + def test_custom_timesteps_increasing_order(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + timesteps = [100, 87, 50, 51, 0] + + with self.assertRaises(ValueError, msg="`custom_timesteps` must be in descending order."): + scheduler.set_timesteps(timesteps=timesteps) + + def test_custom_timesteps_passing_both_num_inference_steps_and_timesteps(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + timesteps = [100, 87, 50, 1, 0] + num_inference_steps = len(timesteps) + + with self.assertRaises(ValueError, msg="Can only pass one of `num_inference_steps` or `custom_timesteps`."): + scheduler.set_timesteps(num_inference_steps=num_inference_steps, timesteps=timesteps) + + def test_custom_timesteps_too_large(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + timesteps = [scheduler.config.num_train_timesteps] + + with self.assertRaises( + ValueError, + msg="`timesteps` must start before `self.config.train_timesteps`: {scheduler.config.num_train_timesteps}}", + ): + scheduler.set_timesteps(timesteps=timesteps) From c7bfb8b22a4caaa5a32c503d7ff3f6db61390af0 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 27 Nov 2023 19:43:49 +0800 Subject: [PATCH 12/23] set the model to train state before accelerator prepare (#5099) Signed-off-by: Wang, Yi A --- examples/textual_inversion/textual_inversion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 8e932add92af..7fea4fdb6440 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -761,6 +761,7 @@ def main(): num_cycles=args.lr_num_cycles, ) + text_encoder.train() # Prepare everything with our `accelerator`. text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( text_encoder, optimizer, train_dataloader, lr_scheduler From c079cae3d4792fb4099dd8082407cec71bf695d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20de=20Prado?= Date: Mon, 27 Nov 2023 12:46:26 +0100 Subject: [PATCH 13/23] Avoid computing min() that is expensive when do_normalize is False in the image processor (#5896) Avoid computing min() that is expensive when do_normalize is False Avoid extra computing when do_normalize is False --- src/diffusers/image_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index de60c46eb239..3da5f7014169 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -326,7 +326,7 @@ def preprocess( # expected range [0,1], normalize to [-1,1] do_normalize = self.config.do_normalize - if image.min() < 0 and do_normalize: + if do_normalize and image.min() < 0: warnings.warn( "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", From 07eac4d65a8ec67e7ae971da4431f67095e9db8a Mon Sep 17 00:00:00 2001 From: dg845 <58458699+dg845@users.noreply.github.com> Date: Mon, 27 Nov 2023 04:00:40 -0800 Subject: [PATCH 14/23] Fix LCM Stable Diffusion distillation bug related to parsing unet_time_cond_proj_dim (#5893) * Fix bug related to parsing unet_time_cond_proj_dim. * Fix analogous bug in the SD-XL LCM distillation script. --- .../train_lcm_distill_sd_wds.py | 11 ++++++++++- .../train_lcm_distill_sdxl_wds.py | 14 ++++++++++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index ec4bf432f03d..4c4ad984fc31 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -657,6 +657,15 @@ def parse_args(): default=0.001, help="The huber loss parameter. Only used if `--loss_type=huber`.", ) + parser.add_argument( + "--unet_time_cond_proj_dim", + type=int, + default=256, + help=( + "The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net" + " does not have `time_cond_proj_dim` set." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--ema_decay", @@ -1138,7 +1147,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min - w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim) + w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim) w = w.reshape(bsz, 1, 1, 1) # Move to U-Net device and dtype w = w.to(device=latents.device, dtype=latents.dtype) diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index 7d2b1e103208..920950d0f6e6 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -677,6 +677,15 @@ def parse_args(): default=0.001, help="The huber loss parameter. Only used if `--loss_type=huber`.", ) + parser.add_argument( + "--unet_time_cond_proj_dim", + type=int, + default=256, + help=( + "The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net" + " does not have `time_cond_proj_dim` set." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--ema_decay", @@ -1233,6 +1242,7 @@ def compute_embeddings( # 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min + w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim) w = w.reshape(bsz, 1, 1, 1) w = w.to(device=latents.device, dtype=latents.dtype) @@ -1243,7 +1253,7 @@ def compute_embeddings( noise_pred = unet( noisy_model_input, start_timesteps, - timestep_cond=None, + timestep_cond=w_embedding, encoder_hidden_states=prompt_embeds.float(), added_cond_kwargs=encoded_text, ).sample @@ -1308,7 +1318,7 @@ def compute_embeddings( target_noise_pred = target_unet( x_prev.float(), timesteps, - timestep_cond=None, + timestep_cond=w_embedding, encoder_hidden_states=prompt_embeds.float(), added_cond_kwargs=encoded_text, ).sample From d3cda804e709e914e074849325f2af96590b9ecf Mon Sep 17 00:00:00 2001 From: ginjia Date: Mon, 27 Nov 2023 20:32:43 +0800 Subject: [PATCH 15/23] add LoRA weights load and fuse support for IPEX pipeline (#5920) add IPEX pipeline LoRA weights loading support --- examples/community/stable_diffusion_ipex.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py index 385227db0b70..6d86248acbe6 100644 --- a/examples/community/stable_diffusion_ipex.py +++ b/examples/community/stable_diffusion_ipex.py @@ -21,7 +21,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from diffusers.configuration_utils import FrozenDict -from diffusers.loaders import TextualInversionLoaderMixin +from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput @@ -62,7 +62,7 @@ """ -class StableDiffusionIPEXPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class StableDiffusionIPEXPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion on IPEX. From d72a24b790d25195d1fe1a1d0c370a6477617c89 Mon Sep 17 00:00:00 2001 From: Chi Date: Mon, 27 Nov 2023 18:04:52 +0530 Subject: [PATCH 16/23] Replace multiple variables with one variable. (#5715) * I added a new doc string to the class. This is more flexible to understanding other developers what are doing and where it's using. * Update src/diffusers/models/unet_2d_blocks.py This changes suggest by maintener. Co-authored-by: Sayak Paul * Update src/diffusers/models/unet_2d_blocks.py Add suggested text Co-authored-by: Sayak Paul * Update unet_2d_blocks.py I changed the Parameter to Args text. * Update unet_2d_blocks.py proper indentation set in this file. * Update unet_2d_blocks.py a little bit of change in the act_fun argument line. * I run the black command to reformat style in the code * Update unet_2d_blocks.py similar doc-string add to have in the original diffusion repository. * I enhanced the code by replacing multiple redundant variables with a single variable, as they all served the same purpose. Additionally, I utilized the get_activation function for improved flexibility in choosing activation functions. * Using as black package to reformated my file * reverte some changes * Remove conv_out_padding variables and using as conv_in_padding * conv_out_padding create and add them into the code. * run black command to solving styling problem * add little bit space between comment and import statement * I am utilizing the ruff library to address the style issues in my Makefile. --------- Co-authored-by: Sayak Paul Co-authored-by: YiYi Xu Co-authored-by: Patrick von Platen --- src/diffusers/models/unet_3d_condition.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index c6710256ef39..3c76b5aa8452 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union @@ -22,6 +23,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import UNet2DConditionLoadersMixin from ..utils import BaseOutput, logging +from .activations import get_activation from .attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, @@ -271,7 +273,7 @@ def __init__( self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) - self.conv_act = nn.SiLU() + self.conv_act = get_activation("silu") else: self.conv_norm_out = None self.conv_act = None From 20f0cbc88ff6bda5cf0cf6dba2ccf7faa3275d9f Mon Sep 17 00:00:00 2001 From: Viktor Grygorchuk <43035191+VicGrygorchyk@users.noreply.github.com> Date: Mon, 27 Nov 2023 14:47:47 +0200 Subject: [PATCH 17/23] fix: error on device for `lpw_stable_diffusion_xl` pipeline if `pipe.enable_sequential_cpu_offload()` enabled (#5885) fix: set device for pipe.enable_sequential_cpu_offload() --- examples/community/lpw_stable_diffusion_xl.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py index cb955a688643..dfe60d9794e1 100644 --- a/examples/community/lpw_stable_diffusion_xl.py +++ b/examples/community/lpw_stable_diffusion_xl.py @@ -250,6 +250,7 @@ def get_weighted_text_embeddings_sdxl( neg_prompt: str = "", neg_prompt_2: str = None, num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, ): """ This function can process long prompt with weights, no length limitation @@ -262,10 +263,13 @@ def get_weighted_text_embeddings_sdxl( neg_prompt (str) neg_prompt_2 (str) num_images_per_prompt (int) + device (torch.device) Returns: prompt_embeds (torch.Tensor) neg_prompt_embeds (torch.Tensor) """ + device = device or pipe._execution_device + if prompt_2: prompt = f"{prompt} {prompt_2}" @@ -330,17 +334,17 @@ def get_weighted_text_embeddings_sdxl( # get prompt embeddings one by one is not working. for i in range(len(prompt_token_groups)): # get positive prompt embeddings with weights - token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device) - weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device) + token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=device) + weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=device) - token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device) + token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=device) # use first text encoder - prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True) + prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True) prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2] # use second text encoder - prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True) + prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True) prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2] pooled_prompt_embeds = prompt_embeds_2[0] @@ -357,16 +361,16 @@ def get_weighted_text_embeddings_sdxl( embeds.append(token_embedding) # get negative prompt embeddings with weights - neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device) - neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device) - neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device) + neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=device) + neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=device) + neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=device) # use first text encoder - neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True) + neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(device), output_hidden_states=True) neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2] # use second text encoder - neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True) + neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(device), output_hidden_states=True) neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2] negative_pooled_prompt_embeds = neg_prompt_embeds_2[0] From e550163b9f503bfd3c941cbd0d9f31dcdc313429 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Nov 2023 14:17:47 +0100 Subject: [PATCH 18/23] [Vae] Make sure all vae's work with latent diffusion models (#5880) * add comments to explain the code better * add comments to explain the code better * add comments to explain the code better * add comments to explain the code better * add comments to explain the code better * fix more * fix more * fix more * fix more * fix more * fix more --- src/diffusers/models/autoencoder_asym_kl.py | 3 + src/diffusers/models/autoencoder_tiny.py | 3 + .../models/consistency_decoder_vae.py | 1 + .../pipeline_alt_diffusion_img2img.py | 8 +- .../controlnet/pipeline_controlnet_img2img.py | 8 +- .../controlnet/pipeline_controlnet_inpaint.py | 8 +- .../pipeline_controlnet_inpaint_sd_xl.py | 18 ++- .../pipeline_controlnet_sd_xl_img2img.py | 8 +- .../pipeline_latent_consistency_img2img.py | 8 +- .../pipeline_paint_by_example.py | 8 +- .../pipeline_cycle_diffusion.py | 19 ++- .../pipeline_stable_diffusion_depth2img.py | 8 +- .../pipeline_stable_diffusion_img2img.py | 8 +- .../pipeline_stable_diffusion_inpaint.py | 8 +- ...eline_stable_diffusion_instruct_pix2pix.py | 27 ++-- .../pipeline_stable_diffusion_xl_img2img.py | 8 +- .../pipeline_stable_diffusion_xl_inpaint.py | 8 +- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 27 ++-- .../pipeline_text_to_video_synth_img2img.py | 20 ++- tests/models/test_models_vae.py | 138 ++++++++++-------- tests/pipelines/test_pipelines_common.py | 45 +++++- 21 files changed, 277 insertions(+), 112 deletions(-) diff --git a/src/diffusers/models/autoencoder_asym_kl.py b/src/diffusers/models/autoencoder_asym_kl.py index 656683b43f60..818e181fcdf0 100644 --- a/src/diffusers/models/autoencoder_asym_kl.py +++ b/src/diffusers/models/autoencoder_asym_kl.py @@ -108,6 +108,9 @@ def __init__( self.use_slicing = False self.use_tiling = False + self.register_to_config(block_out_channels=up_block_out_channels) + self.register_to_config(force_upcast=False) + @apply_forward_hook def encode( self, x: torch.FloatTensor, return_dict: bool = True diff --git a/src/diffusers/models/autoencoder_tiny.py b/src/diffusers/models/autoencoder_tiny.py index d2d2f6f9404f..56ccf30e0402 100644 --- a/src/diffusers/models/autoencoder_tiny.py +++ b/src/diffusers/models/autoencoder_tiny.py @@ -148,6 +148,9 @@ def __init__( self.tile_sample_min_size = 512 self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor + self.register_to_config(block_out_channels=decoder_block_out_channels) + self.register_to_config(force_upcast=False) + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: if isinstance(module, (EncoderTiny, DecoderTiny)): module.gradient_checkpointing = value diff --git a/src/diffusers/models/consistency_decoder_vae.py b/src/diffusers/models/consistency_decoder_vae.py index a2d82e2565ed..34176a35e835 100644 --- a/src/diffusers/models/consistency_decoder_vae.py +++ b/src/diffusers/models/consistency_decoder_vae.py @@ -138,6 +138,7 @@ def __init__( ) self.decoder_scheduler = ConsistencyDecoderScheduler() self.register_to_config(block_out_channels=encoder_block_out_channels) + self.register_to_config(force_upcast=False) self.register_buffer( "means", torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None], diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 9b5eb1b4c66d..4272fa124755 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -76,9 +76,13 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 8945bd3d9c81..fa489941c987 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -92,9 +92,13 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 9e2e428eaf91..7bbc4889e7ac 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -104,9 +104,13 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 3e5cba79f50b..0f51ad58a598 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -54,6 +54,20 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -824,12 +838,12 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) if self.vae.config.force_upcast: self.vae.to(dtype) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 4fccd6a91b0f..ba18567b60f7 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -133,9 +133,13 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index 0e7bd6e72281..ed29a939388f 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -44,9 +44,13 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index 38b90b10ad4b..0a20981beb05 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -35,9 +35,13 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index 10adefcff000..e5c2c78720d5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -61,6 +61,20 @@ def preprocess(image): return image +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta): # 1. get previous step value (=t-1) prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps @@ -567,11 +581,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt if isinstance(generator, list): init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) init_latents = self.vae.config.scaling_factor * init_latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 6a712692ac49..e431fee7bdb0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -37,9 +37,13 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 1bec0807a2e0..e3a1a0ed3660 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -73,9 +73,13 @@ """ -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 251dfb5676c1..3570eaa6fd3d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -160,9 +160,13 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 49da65bfbe9f..d922803858b0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -58,6 +58,20 @@ def preprocess(image): return image +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion). @@ -320,7 +334,6 @@ def __call__( prompt_embeds.dtype, device, self.do_classifier_free_guidance, - generator, ) height, width = image_latents.shape[-2:] @@ -716,17 +729,7 @@ def prepare_image_latents( if image.shape[1] == 4: image_latents = image else: - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if isinstance(generator, list): - image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = self.vae.encode(image).latent_dist.mode() + image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: # expand image_latents for batch_size diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index dc8b95bf99a6..436d816e5eb3 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -105,9 +105,13 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index e49ec0d607d6..f54b680dfd7c 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -250,9 +250,13 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index d639bee39a9f..b14c746f9962 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -88,6 +88,20 @@ """ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and @@ -533,17 +547,7 @@ def prepare_image_latents( self.upcast_vae() image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if isinstance(generator, list): - image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = self.vae.encode(image).latent_dist.mode() + image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax") # cast back to fp16 if needed if needs_upcasting: @@ -866,7 +870,6 @@ def __call__( prompt_embeds.dtype, device, do_classifier_free_guidance, - generator, ) # 7. Prepare latent variables diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index dae7127c22c1..6779a7b820c2 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -79,6 +79,20 @@ """ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]: # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 # reshape to ncfhw @@ -547,14 +561,14 @@ def prepare_latents(self, video, timestep, batch_size, dtype, device, generator= f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - elif isinstance(generator, list): init_latents = [ - self.vae.encode(video[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + retrieve_latents(self.vae.encode(video[i : i + 1]), generator=generator[i]) + for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = self.vae.encode(video).latent_dist.sample(generator) + init_latents = retrieve_latents(self.vae.encode(video), generator=generator) init_latents = self.vae.config.scaling_factor * init_latents diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index 842a08c90bf4..83788b836a78 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -46,6 +46,82 @@ enable_full_determinism() +def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None): + block_out_channels = block_out_channels or [32, 64] + norm_num_groups = norm_num_groups or 32 + init_dict = { + "block_out_channels": block_out_channels, + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels), + "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels), + "latent_channels": 4, + "norm_num_groups": norm_num_groups, + } + return init_dict + + +def get_asym_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None): + block_out_channels = block_out_channels or [32, 64] + norm_num_groups = norm_num_groups or 32 + init_dict = { + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels), + "down_block_out_channels": block_out_channels, + "layers_per_down_block": 1, + "up_block_types": ["UpDecoderBlock2D"] * len(block_out_channels), + "up_block_out_channels": block_out_channels, + "layers_per_up_block": 1, + "act_fn": "silu", + "latent_channels": 4, + "norm_num_groups": norm_num_groups, + "sample_size": 32, + "scaling_factor": 0.18215, + } + return init_dict + + +def get_autoencoder_tiny_config(block_out_channels=None): + block_out_channels = (len(block_out_channels) * [32]) if block_out_channels is not None else [32, 32] + init_dict = { + "in_channels": 3, + "out_channels": 3, + "encoder_block_out_channels": block_out_channels, + "decoder_block_out_channels": block_out_channels, + "num_encoder_blocks": [b // min(block_out_channels) for b in block_out_channels], + "num_decoder_blocks": [b // min(block_out_channels) for b in reversed(block_out_channels)], + } + return init_dict + + +def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None): + block_out_channels = block_out_channels or [32, 64] + norm_num_groups = norm_num_groups or 32 + return { + "encoder_block_out_channels": block_out_channels, + "encoder_in_channels": 3, + "encoder_out_channels": 4, + "encoder_down_block_types": ["DownEncoderBlock2D"] * len(block_out_channels), + "decoder_add_attention": False, + "decoder_block_out_channels": block_out_channels, + "decoder_down_block_types": ["ResnetDownsampleBlock2D"] * len(block_out_channels), + "decoder_downsample_padding": 1, + "decoder_in_channels": 7, + "decoder_layers_per_block": 1, + "decoder_norm_eps": 1e-05, + "decoder_norm_num_groups": norm_num_groups, + "encoder_norm_num_groups": norm_num_groups, + "decoder_num_train_timesteps": 1024, + "decoder_out_channels": 6, + "decoder_resnet_time_scale_shift": "scale_shift", + "decoder_time_embedding_type": "learned", + "decoder_up_block_types": ["ResnetUpsampleBlock2D"] * len(block_out_channels), + "scaling_factor": 1, + "latent_channels": 4, + } + + class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = AutoencoderKL main_input_name = "sample" @@ -70,14 +146,7 @@ def output_shape(self): return (3, 32, 32) def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "block_out_channels": [32, 64], - "in_channels": 3, - "out_channels": 3, - "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], - "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], - "latent_channels": 4, - } + init_dict = get_autoencoder_kl_config() inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -214,21 +283,7 @@ def output_shape(self): return (3, 32, 32) def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "in_channels": 3, - "out_channels": 3, - "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], - "down_block_out_channels": [32, 64], - "layers_per_down_block": 1, - "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], - "up_block_out_channels": [32, 64], - "layers_per_up_block": 1, - "act_fn": "silu", - "latent_channels": 4, - "norm_num_groups": 32, - "sample_size": 32, - "scaling_factor": 0.18215, - } + init_dict = get_asym_autoencoder_kl_config() inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -263,14 +318,7 @@ def output_shape(self): return (3, 32, 32) def prepare_init_args_and_inputs_for_common(self): - init_dict = { - "in_channels": 3, - "out_channels": 3, - "encoder_block_out_channels": (32, 32), - "decoder_block_out_channels": (32, 32), - "num_encoder_blocks": (1, 2), - "num_decoder_blocks": (2, 1), - } + init_dict = get_autoencoder_tiny_config() inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -302,33 +350,7 @@ def output_shape(self): @property def init_dict(self): - return { - "encoder_block_out_channels": [32, 64], - "encoder_in_channels": 3, - "encoder_out_channels": 4, - "encoder_down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], - "decoder_add_attention": False, - "decoder_block_out_channels": [32, 64], - "decoder_down_block_types": [ - "ResnetDownsampleBlock2D", - "ResnetDownsampleBlock2D", - ], - "decoder_downsample_padding": 1, - "decoder_in_channels": 7, - "decoder_layers_per_block": 1, - "decoder_norm_eps": 1e-05, - "decoder_norm_num_groups": 32, - "decoder_num_train_timesteps": 1024, - "decoder_out_channels": 6, - "decoder_resnet_time_scale_shift": "scale_shift", - "decoder_time_embedding_type": "learned", - "decoder_up_block_types": [ - "ResnetUpsampleBlock2D", - "ResnetUpsampleBlock2D", - ], - "scaling_factor": 1, - "latent_channels": 4, - } + return get_consistency_vae_config() def prepare_init_args_and_inputs_for_common(self): return self.init_dict, self.inputs_dict() diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index dfe523cda9d4..e11175921184 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -17,7 +17,16 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer import diffusers -from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel +from diffusers import ( + AsymmetricAutoencoderKL, + AutoencoderKL, + AutoencoderTiny, + ConsistencyDecoderVAE, + DDIMScheduler, + DiffusionPipeline, + StableDiffusionPipeline, + UNet2DConditionModel, +) from diffusers.image_processor import VaeImageProcessor from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import logging @@ -28,6 +37,12 @@ torch_device, ) +from ..models.test_models_vae import ( + get_asym_autoencoder_kl_config, + get_autoencoder_kl_config, + get_autoencoder_tiny_config, + get_consistency_vae_config, +) from ..others.test_utils import TOKEN, USER, is_staging_test @@ -171,6 +186,34 @@ def test_latents_input(self): max_diff = np.abs(out - out_latents_inputs).max() self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image") + def test_multi_vae(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + block_out_channels = pipe.vae.config.block_out_channels + norm_num_groups = pipe.vae.config.norm_num_groups + + vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny] + configs = [ + get_autoencoder_kl_config(block_out_channels, norm_num_groups), + get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups), + get_consistency_vae_config(block_out_channels, norm_num_groups), + get_autoencoder_tiny_config(block_out_channels), + ] + + out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0] + + for vae_cls, config in zip(vae_classes, configs): + vae = vae_cls(**config) + vae = vae.to(torch_device) + components["vae"] = vae + vae_pipe = self.pipeline_class(**components) + out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0] + + assert out_vae_np.shape == out_np.shape + @require_torch class PipelineKarrasSchedulerTesterMixin: From ebf581e85f3aad7faa30ceb4678148ee87375446 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Nov 2023 14:18:56 +0100 Subject: [PATCH 19/23] [Tests] Make sure that we don't run tests multiple times (#5949) * [Tests] Make sure that we don't run tests mulitple times * [Tests] Make sure that we don't run tests mulitple times * [Tests] Make sure that we don't run tests mulitple times --- .github/workflows/pr_test_fetcher.yml | 6 +++++- .github/workflows/push_tests_fast.yml | 4 ++++ .github/workflows/push_tests_mps.yml | 4 ++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pr_test_fetcher.yml b/.github/workflows/pr_test_fetcher.yml index d33bca1903f4..7eb208505e75 100644 --- a/.github/workflows/pr_test_fetcher.yml +++ b/.github/workflows/pr_test_fetcher.yml @@ -1,4 +1,4 @@ -name: Fast tests for PRs +name: Fast tests for PRs - Test Fetcher on: pull_request: @@ -14,6 +14,10 @@ env: MKL_NUM_THREADS: 4 PYTEST_TIMEOUT: 60 +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + jobs: setup_pr_tests: name: Setup PR Tests diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml index acd59ef80dc7..798fa777c6c6 100644 --- a/.github/workflows/push_tests_fast.yml +++ b/.github/workflows/push_tests_fast.yml @@ -5,6 +5,10 @@ on: branches: - main +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + env: DIFFUSERS_IS_CI: yes HF_HOME: /mnt/cache diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml index c92aa6426d55..bdea0b760b26 100644 --- a/.github/workflows/push_tests_mps.yml +++ b/.github/workflows/push_tests_mps.yml @@ -13,6 +13,10 @@ env: PYTEST_TIMEOUT: 600 RUN_SLOW: no +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + jobs: run_fast_tests_apple_m1: name: Fast PyTorch MPS tests on MacOS From 14a0d21d2ea2809ca9f88958edc459b5a1c81a16 Mon Sep 17 00:00:00 2001 From: "T. Xu" Date: Mon, 27 Nov 2023 21:29:42 +0800 Subject: [PATCH 20/23] [Community Pipeline] Diffusion Posterior Sampling for General Noisy Inverse Problems (#5939) * [community pipeline] dps impl * add type checking * pass ruff check * ruff formatter --- examples/community/README.md | 144 ++++++++- examples/community/dps_pipeline.py | 466 +++++++++++++++++++++++++++++ 2 files changed, 609 insertions(+), 1 deletion(-) create mode 100755 examples/community/dps_pipeline.py diff --git a/examples/community/README.md b/examples/community/README.md index 96d530412979..e076904a5cfc 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -2480,4 +2480,146 @@ images = pipe( ).images images[0].save("controlnet_and_adapter_inpaint.png") -``` \ No newline at end of file +``` + +## Diffusion Posterior Sampling Pipeline +* Reference paper + ``` + @article{chung2022diffusion, + title={Diffusion posterior sampling for general noisy inverse problems}, + author={Chung, Hyungjin and Kim, Jeongsol and Mccann, Michael T and Klasky, Marc L and Ye, Jong Chul}, + journal={arXiv preprint arXiv:2209.14687}, + year={2022} + } + ``` +* This pipeline allows zero-shot conditional sampling from the posterior distribution $p(x|y)$, given observation on $y$, unconditional generative model $p(x)$ and differentiable operator $y=f(x)$. +* For example, $f(.)$ can be downsample operator, then $y$ is a downsampled image, and the pipeline becomes a super-resolution pipeline. +* To use this pipeline, you need to know your operator $f(.)$ and corrupted image $y$, and pass them during the call. For example, as in the main function of dps_pipeline.py, you need to first define the Gaussian blurring operator $f(.)$. The operator should be a callable nn.Module, with all the parameter gradient disabled: + ```python + import torch.nn.functional as F + import scipy + from torch import nn + + # define the Gaussian blurring operator first + class GaussialBlurOperator(nn.Module): + def __init__(self, kernel_size, intensity): + super().__init__() + + class Blurkernel(nn.Module): + def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0): + super().__init__() + self.blur_type = blur_type + self.kernel_size = kernel_size + self.std = std + self.seq = nn.Sequential( + nn.ReflectionPad2d(self.kernel_size//2), + nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3) + ) + self.weights_init() + + def forward(self, x): + return self.seq(x) + + def weights_init(self): + if self.blur_type == "gaussian": + n = np.zeros((self.kernel_size, self.kernel_size)) + n[self.kernel_size // 2,self.kernel_size // 2] = 1 + k = scipy.ndimage.gaussian_filter(n, sigma=self.std) + k = torch.from_numpy(k) + self.k = k + for name, f in self.named_parameters(): + f.data.copy_(k) + elif self.blur_type == "motion": + k = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix + k = torch.from_numpy(k) + self.k = k + for name, f in self.named_parameters(): + f.data.copy_(k) + + def update_weights(self, k): + if not torch.is_tensor(k): + k = torch.from_numpy(k) + for name, f in self.named_parameters(): + f.data.copy_(k) + + def get_kernel(self): + return self.k + + self.kernel_size = kernel_size + self.conv = Blurkernel(blur_type='gaussian', + kernel_size=kernel_size, + std=intensity) + self.kernel = self.conv.get_kernel() + self.conv.update_weights(self.kernel.type(torch.float32)) + + for param in self.parameters(): + param.requires_grad=False + + def forward(self, data, **kwargs): + return self.conv(data) + + def transpose(self, data, **kwargs): + return data + + def get_kernel(self): + return self.kernel.view(1, 1, self.kernel_size, self.kernel_size) + ``` +* Next, you should obtain the corrupted image $y$ by the operator. In this example, we generate $y$ from the source image $x$. However in practice, having the operator $f(.)$ and corrupted image $y$ is enough: + ```python + # set up source image + src = Image.open('sample.png') + # read image into [1,3,H,W] + src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2,0,1)[None] + # normalize image to [-1,1] + src = (src / 127.5) - 1.0 + src = src.to("cuda") + + # set up operator and measurement + operator = GaussialBlurOperator(kernel_size=61, intensity=3.0).to("cuda") + measurement = operator(src) + + # save the source and corrupted images + save_image((src+1.0)/2.0, "dps_src.png") + save_image((measurement+1.0)/2.0, "dps_mea.png") + ``` +* We provide an example pair of saved source and corrupted images, using the Gaussian blur operator above + * Source image: + * ![sample](https://github.com/tongdaxu/Images/assets/22267548/4d2a1216-08d1-4aeb-9ce3-7a2d87561d65) + * Gaussian blurred image: + * ![ddpm_generated_image](https://github.com/tongdaxu/Images/assets/22267548/65076258-344b-4ed8-b704-a04edaade8ae) + * You can download those image to run the example on your own. +* Next, we need to define a loss function used for diffusion posterior sample. For most of the cases, the RMSE is fine: + ```python + def RMSELoss(yhat, y): + return torch.sqrt(torch.sum((yhat-y)**2)) + ``` +* And next, as any other diffusion models, we need the score estimator and scheduler. As we are working with $256x256$ face images, we use ddmp-celebahq-256: + ```python + # set up scheduler + scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256") + scheduler.set_timesteps(1000) + + # set up model + model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda") + ``` +* And finally, run the pipeline: + ```python + # finally, the pipeline + dpspipe = DPSPipeline(model, scheduler) + image = dpspipe( + measurement = measurement, + operator = operator, + loss_fn = RMSELoss, + zeta = 1.0, + ).images[0] + image.save("dps_generated_image.png") + ``` +* The zeta is a hyperparameter that is in range of $[0,1]$. It need to be tuned for best effect. By setting zeta=1, you should be able to have the reconstructed result: + * Reconstructed image: + * ![sample](https://github.com/tongdaxu/Images/assets/22267548/0ceb5575-d42e-4f0b-99c0-50e69c982209) +* The reconstruction is perceptually similar to the source image, but different in details. +* In dps_pipeline.py, we also provide a super-resolution example, which should produce: + * Downsampled image: + * ![dps_mea](https://github.com/tongdaxu/Images/assets/22267548/ff6a33d6-26f0-42aa-88ce-f8a76ba45a13) + * Reconstructed image: + * ![dps_generated_image](https://github.com/tongdaxu/Images/assets/22267548/b74f084d-93f4-4845-83d8-44c0fa758a5f) diff --git a/examples/community/dps_pipeline.py b/examples/community/dps_pipeline.py new file mode 100755 index 000000000000..87919b0f503a --- /dev/null +++ b/examples/community/dps_pipeline.py @@ -0,0 +1,466 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from math import pi +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image + +from diffusers import DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DModel +from diffusers.utils.torch_utils import randn_tensor + + +class DPSPipeline(DiffusionPipeline): + r""" + Pipeline for Diffusion Posterior Sampling. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Parameters: + unet ([`UNet2DModel`]): + A `UNet2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of + [`DDPMScheduler`], or [`DDIMScheduler`]. + """ + + model_cpu_offload_seq = "unet" + + def __init__(self, unet, scheduler): + super().__init__() + self.register_modules(unet=unet, scheduler=scheduler) + + @torch.no_grad() + def __call__( + self, + measurement: torch.Tensor, + operator: torch.nn.Module, + loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], + batch_size: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 1000, + output_type: Optional[str] = "pil", + return_dict: bool = True, + zeta: float = 0.3, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + The call function to the pipeline for generation. + + Args: + measurement (`torch.Tensor`, *required*): + A 'torch.Tensor', the corrupted image + operator (`torch.nn.Module`, *required*): + A 'torch.nn.Module', the operator generating the corrupted image + loss_fn (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *required*): + A 'Callable[[torch.Tensor, torch.Tensor], torch.Tensor]', the loss function used + between the measurements, for most of the cases using RMSE is fine. + batch_size (`int`, *optional*, defaults to 1): + The number of images to generate. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + num_inference_steps (`int`, *optional*, defaults to 1000): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Example: + + ```py + >>> from diffusers import DDPMPipeline + + >>> # load model and scheduler + >>> pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256") + + >>> # run pipeline in inference (sample random noise and denoise) + >>> image = pipe().images[0] + + >>> # save image + >>> image.save("ddpm_generated_image.png") + ``` + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + # Sample gaussian noise to begin loop + if isinstance(self.unet.config.sample_size, int): + image_shape = ( + batch_size, + self.unet.config.in_channels, + self.unet.config.sample_size, + self.unet.config.sample_size, + ) + else: + image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size) + + if self.device.type == "mps": + # randn does not work reproducibly on mps + image = randn_tensor(image_shape, generator=generator) + image = image.to(self.device) + else: + image = randn_tensor(image_shape, generator=generator, device=self.device) + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + with torch.enable_grad(): + # 1. predict noise model_output + image = image.requires_grad_() + model_output = self.unet(image, t).sample + + # 2. compute previous image x'_{t-1} and original prediction x0_{t} + scheduler_out = self.scheduler.step(model_output, t, image, generator=generator) + image_pred, origi_pred = scheduler_out.prev_sample, scheduler_out.pred_original_sample + + # 3. compute y'_t = f(x0_{t}) + measurement_pred = operator(origi_pred) + + # 4. compute loss = d(y, y'_t-1) + loss = loss_fn(measurement, measurement_pred) + loss.backward() + + print("distance: {0:.4f}".format(loss.item())) + + with torch.no_grad(): + image_pred = image_pred - zeta * image.grad + image = image_pred.detach() + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + +if __name__ == "__main__": + import scipy + from torch import nn + from torchvision.utils import save_image + + # defining the operators f(.) of y = f(x) + # super-resolution operator + class SuperResolutionOperator(nn.Module): + def __init__(self, in_shape, scale_factor): + super().__init__() + + # Resizer local class, do not use outiside the SR operator class + class Resizer(nn.Module): + def __init__(self, in_shape, scale_factor=None, output_shape=None, kernel=None, antialiasing=True): + super(Resizer, self).__init__() + + # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa + scale_factor, output_shape = self.fix_scale_and_size(in_shape, output_shape, scale_factor) + + # Choose interpolation method, each method has the matching kernel size + def cubic(x): + absx = np.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * (absx <= 1) + ( + -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2 + ) * ((1 < absx) & (absx <= 2)) + + def lanczos2(x): + return ( + (np.sin(pi * x) * np.sin(pi * x / 2) + np.finfo(np.float32).eps) + / ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps) + ) * (abs(x) < 2) + + def box(x): + return ((-0.5 <= x) & (x < 0.5)) * 1.0 + + def lanczos3(x): + return ( + (np.sin(pi * x) * np.sin(pi * x / 3) + np.finfo(np.float32).eps) + / ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps) + ) * (abs(x) < 3) + + def linear(x): + return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1)) + + method, kernel_width = { + "cubic": (cubic, 4.0), + "lanczos2": (lanczos2, 4.0), + "lanczos3": (lanczos3, 6.0), + "box": (box, 1.0), + "linear": (linear, 2.0), + None: (cubic, 4.0), # set default interpolation method as cubic + }.get(kernel) + + # Antialiasing is only used when downscaling + antialiasing *= np.any(np.array(scale_factor) < 1) + + # Sort indices of dimensions according to scale of each dimension. since we are going dim by dim this is efficient + sorted_dims = np.argsort(np.array(scale_factor)) + self.sorted_dims = [int(dim) for dim in sorted_dims if scale_factor[dim] != 1] + + # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction + field_of_view_list = [] + weights_list = [] + for dim in self.sorted_dims: + # for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the + # weights that multiply the values there to get its result. + weights, field_of_view = self.contributions( + in_shape[dim], output_shape[dim], scale_factor[dim], method, kernel_width, antialiasing + ) + + # convert to torch tensor + weights = torch.tensor(weights.T, dtype=torch.float32) + + # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for + # tmp_im[field_of_view.T], (bsxfun style) + weights_list.append( + nn.Parameter( + torch.reshape(weights, list(weights.shape) + (len(scale_factor) - 1) * [1]), + requires_grad=False, + ) + ) + field_of_view_list.append( + nn.Parameter( + torch.tensor(field_of_view.T.astype(np.int32), dtype=torch.long), requires_grad=False + ) + ) + + self.field_of_view = nn.ParameterList(field_of_view_list) + self.weights = nn.ParameterList(weights_list) + + def forward(self, in_tensor): + x = in_tensor + + # Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim + for dim, fov, w in zip(self.sorted_dims, self.field_of_view, self.weights): + # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize + x = torch.transpose(x, dim, 0) + + # This is a bit of a complicated multiplication: x[field_of_view.T] is a tensor of order image_dims+1. + # for each pixel in the output-image it matches the positions the influence it from the input image (along 1 dim + # only, this is why it only adds 1 dim to 5the shape). We then multiply, for each pixel, its set of positions with + # the matching set of weights. we do this by this big tensor element-wise multiplication (MATLAB bsxfun style: + # matching dims are multiplied element-wise while singletons mean that the matching dim is all multiplied by the + # same number + x = torch.sum(x[fov] * w, dim=0) + + # Finally we swap back the axes to the original order + x = torch.transpose(x, dim, 0) + + return x + + def fix_scale_and_size(self, input_shape, output_shape, scale_factor): + # First fixing the scale-factor (if given) to be standardized the function expects (a list of scale factors in the + # same size as the number of input dimensions) + if scale_factor is not None: + # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it. + if np.isscalar(scale_factor) and len(input_shape) > 1: + scale_factor = [scale_factor, scale_factor] + + # We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales + scale_factor = list(scale_factor) + scale_factor = [1] * (len(input_shape) - len(scale_factor)) + scale_factor + + # Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size + # to all the unspecified dimensions + if output_shape is not None: + output_shape = list(input_shape[len(output_shape) :]) + list(np.uint(np.array(output_shape))) + + # Dealing with the case of non-give scale-factor, calculating according to output-shape. note that this is + # sub-optimal, because there can be different scales to the same output-shape. + if scale_factor is None: + scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape) + + # Dealing with missing output-shape. calculating according to scale-factor + if output_shape is None: + output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor))) + + return scale_factor, output_shape + + def contributions(self, in_length, out_length, scale, kernel, kernel_width, antialiasing): + # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied + # such that each position from the field_of_view will be multiplied with a matching filter from the + # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers + # around it. This is only done for one dimension of the image. + + # When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of + # 1/sf. this means filtering is more 'low-pass filter'. + fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel + kernel_width *= 1.0 / scale if antialiasing else 1.0 + + # These are the coordinates of the output image + out_coordinates = np.arange(1, out_length + 1) + + # since both scale-factor and output size can be provided simulatneously, perserving the center of the image requires shifting + # the output coordinates. the deviation is because out_length doesn't necesary equal in_length*scale. + # to keep the center we need to subtract half of this deivation so that we get equal margins for boths sides and center is preserved. + shifted_out_coordinates = out_coordinates - (out_length - in_length * scale) / 2 + + # These are the matching positions of the output-coordinates on the input image coordinates. + # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels: + # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel. + # The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to + # the right boundary of pixel 2. pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big + # one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor). + # So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is + # at d=1, and so on (d = p - 0.5). we calculate (d_new = d_old / sf) which means: + # (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf) + match_coordinates = shifted_out_coordinates / scale + 0.5 * (1 - 1 / scale) + + # This is the left boundary to start multiplying the filter from, it depends on the size of the filter + left_boundary = np.floor(match_coordinates - kernel_width / 2) + + # Kernel width needs to be enlarged because when covering has sub-pixel borders, it must 'see' the pixel centers + # of the pixels it only covered a part from. So we add one pixel at each side to consider (weights can zeroize them) + expanded_kernel_width = np.ceil(kernel_width) + 2 + + # Determine a set of field_of_view for each each output position, these are the pixels in the input image + # that the pixel in the output image 'sees'. We get a matrix whos horizontal dim is the output pixels (big) and the + # vertical dim is the pixels it 'sees' (kernel_size + 2) + field_of_view = np.squeeze( + np.int16(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1) + ) + + # Assign weight to each pixel in the field of view. A matrix whos horizontal dim is the output pixels and the + # vertical dim is a list of weights matching to the pixel in the field of view (that are specified in + # 'field_of_view') + weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1) + + # Normalize weights to sum up to 1. be careful from dividing by 0 + sum_weights = np.sum(weights, axis=1) + sum_weights[sum_weights == 0] = 1.0 + weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1) + + # We use this mirror structure as a trick for reflection padding at the boundaries + mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1)))) + field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])] + + # Get rid of weights and pixel positions that are of zero weight + non_zero_out_pixels = np.nonzero(np.any(weights, axis=0)) + weights = np.squeeze(weights[:, non_zero_out_pixels]) + field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels]) + + # Final products are the relative positions and the matching weights, both are output_size X fixed_kernel_size + return weights, field_of_view + + self.down_sample = Resizer(in_shape, 1 / scale_factor) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, data, **kwargs): + return self.down_sample(data) + + # Gaussian blurring operator + class GaussialBlurOperator(nn.Module): + def __init__(self, kernel_size, intensity): + super().__init__() + + class Blurkernel(nn.Module): + def __init__(self, blur_type="gaussian", kernel_size=31, std=3.0): + super().__init__() + self.blur_type = blur_type + self.kernel_size = kernel_size + self.std = std + self.seq = nn.Sequential( + nn.ReflectionPad2d(self.kernel_size // 2), + nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3), + ) + self.weights_init() + + def forward(self, x): + return self.seq(x) + + def weights_init(self): + if self.blur_type == "gaussian": + n = np.zeros((self.kernel_size, self.kernel_size)) + n[self.kernel_size // 2, self.kernel_size // 2] = 1 + k = scipy.ndimage.gaussian_filter(n, sigma=self.std) + k = torch.from_numpy(k) + self.k = k + for name, f in self.named_parameters(): + f.data.copy_(k) + + def update_weights(self, k): + if not torch.is_tensor(k): + k = torch.from_numpy(k) + for name, f in self.named_parameters(): + f.data.copy_(k) + + def get_kernel(self): + return self.k + + self.kernel_size = kernel_size + self.conv = Blurkernel(blur_type="gaussian", kernel_size=kernel_size, std=intensity) + self.kernel = self.conv.get_kernel() + self.conv.update_weights(self.kernel.type(torch.float32)) + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, data, **kwargs): + return self.conv(data) + + def transpose(self, data, **kwargs): + return data + + def get_kernel(self): + return self.kernel.view(1, 1, self.kernel_size, self.kernel_size) + + # assuming the forward process y = f(x) is polluted by Gaussian noise, use l2 norm + def RMSELoss(yhat, y): + return torch.sqrt(torch.sum((yhat - y) ** 2)) + + # set up source image + src = Image.open("sample.png") + # read image into [1,3,H,W] + src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2, 0, 1)[None] + # normalize image to [-1,1] + src = (src / 127.5) - 1.0 + src = src.to("cuda") + + # set up operator and measurement + # operator = SuperResolutionOperator(in_shape=src.shape, scale_factor=4).to("cuda") + operator = GaussialBlurOperator(kernel_size=61, intensity=3.0).to("cuda") + measurement = operator(src) + + # set up scheduler + scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256") + scheduler.set_timesteps(1000) + + # set up model + model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda") + + save_image((src + 1.0) / 2.0, "dps_src.png") + save_image((measurement + 1.0) / 2.0, "dps_mea.png") + + # finally, the pipeline + dpspipe = DPSPipeline(model, scheduler) + image = dpspipe( + measurement=measurement, + operator=operator, + loss_fn=RMSELoss, + zeta=1.0, + ).images[0] + + image.save("dps_generated_image.png") From b135b6e905d6ae35a228abc188de11a091da594f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Nov 2023 14:35:19 +0100 Subject: [PATCH 21/23] [From_pretrained] Fix warning (#5948) --- src/diffusers/pipelines/pipeline_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 0208ade020bd..695d961a5d6f 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -259,7 +259,7 @@ def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token, comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision) comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames] - if set(comp_model_filenames) == set(model_filenames): + if set(model_filenames).issubset(set(comp_model_filenames)): warnings.warn( f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.", FutureWarning, From d9075be494ac9796ed096dfb7741c36c33cba813 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Mon, 27 Nov 2023 06:52:36 -1000 Subject: [PATCH 22/23] [load_textual_inversion]: allow multiple tokens (#5837) Co-authored-by: yiyixuxu --- src/diffusers/loaders/textual_inversion.py | 16 +++++++- tests/pipelines/test_pipelines.py | 48 ++++++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/textual_inversion.py b/src/diffusers/loaders/textual_inversion.py index e36f03437a45..d03bd74d5250 100644 --- a/src/diffusers/loaders/textual_inversion.py +++ b/src/diffusers/loaders/textual_inversion.py @@ -189,7 +189,7 @@ def _check_text_inv_inputs(self, tokenizer, text_encoder, pretrained_model_name_ f" `{self.load_textual_inversion.__name__}`" ) - if len(pretrained_model_name_or_paths) != len(tokens): + if len(pretrained_model_name_or_paths) > 1 and len(pretrained_model_name_or_paths) != len(tokens): raise ValueError( f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)} " f"Make sure both lists have the same length." @@ -382,7 +382,9 @@ def load_textual_inversion( if not isinstance(pretrained_model_name_or_path, list) else pretrained_model_name_or_path ) - tokens = len(pretrained_model_name_or_paths) * [token] if (isinstance(token, str) or token is None) else token + tokens = [token] if not isinstance(token, list) else token + if tokens[0] is None: + tokens = tokens * len(pretrained_model_name_or_paths) # 3. Check inputs self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens) @@ -390,6 +392,16 @@ def load_textual_inversion( # 4. Load state dicts of textual embeddings state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs) + # 4.1 Handle the special case when state_dict is a tensor that contains n embeddings for n tokens + if len(tokens) > 1 and len(state_dicts) == 1: + if isinstance(state_dicts[0], torch.Tensor): + state_dicts = list(state_dicts[0]) + if len(tokens) != len(state_dicts): + raise ValueError( + f"You have passed a state_dict contains {len(state_dicts)} embeddings, and list of tokens of length {len(tokens)} " + f"Make sure both have the same length." + ) + # 4. Retrieve tokens and embeddings tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index d812ce0ccb95..32ae81ddc2d8 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -792,6 +792,54 @@ def test_text_inversion_download(self): out = pipe(prompt, num_inference_steps=1, output_type="numpy").images assert out.shape == (1, 128, 128, 3) + def test_text_inversion_multi_tokens(self): + pipe1 = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None + ) + pipe1 = pipe1.to(torch_device) + + token1, token2 = "<*>", "<**>" + ten1 = torch.ones((32,)) + ten2 = torch.ones((32,)) * 2 + + num_tokens = len(pipe1.tokenizer) + + pipe1.load_textual_inversion(ten1, token=token1) + pipe1.load_textual_inversion(ten2, token=token2) + emb1 = pipe1.text_encoder.get_input_embeddings().weight + + pipe2 = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None + ) + pipe2 = pipe2.to(torch_device) + pipe2.load_textual_inversion([ten1, ten2], token=[token1, token2]) + emb2 = pipe2.text_encoder.get_input_embeddings().weight + + pipe3 = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None + ) + pipe3 = pipe3.to(torch_device) + pipe3.load_textual_inversion(torch.stack([ten1, ten2], dim=0), token=[token1, token2]) + emb3 = pipe3.text_encoder.get_input_embeddings().weight + + assert len(pipe1.tokenizer) == len(pipe2.tokenizer) == len(pipe3.tokenizer) == num_tokens + 2 + assert ( + pipe1.tokenizer.convert_tokens_to_ids(token1) + == pipe2.tokenizer.convert_tokens_to_ids(token1) + == pipe3.tokenizer.convert_tokens_to_ids(token1) + == num_tokens + ) + assert ( + pipe1.tokenizer.convert_tokens_to_ids(token2) + == pipe2.tokenizer.convert_tokens_to_ids(token2) + == pipe3.tokenizer.convert_tokens_to_ids(token2) + == num_tokens + 1 + ) + assert emb1[num_tokens].sum().item() == emb2[num_tokens].sum().item() == emb3[num_tokens].sum().item() + assert ( + emb1[num_tokens + 1].sum().item() == emb2[num_tokens + 1].sum().item() == emb3[num_tokens + 1].sum().item() + ) + def test_download_ignore_files(self): # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4 with tempfile.TemporaryDirectory() as tmpdirname: From 50a749e90990932d49556ee54a333278559722f3 Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Mon, 27 Nov 2023 11:50:59 -0800 Subject: [PATCH 23/23] [docs] Fix space (#5898) * fix * minor edits --- .../unconditional_image_generation.md | 49 +++++++------------ 1 file changed, 18 insertions(+), 31 deletions(-) diff --git a/docs/source/en/using-diffusers/unconditional_image_generation.md b/docs/source/en/using-diffusers/unconditional_image_generation.md index 1983f6981e8f..6c55c4edec08 100644 --- a/docs/source/en/using-diffusers/unconditional_image_generation.md +++ b/docs/source/en/using-diffusers/unconditional_image_generation.md @@ -14,54 +14,41 @@ specific language governing permissions and limitations under the License. [[open-in-colab]] -Unconditional image generation is a relatively straightforward task. The model only generates images - without any additional context like text or an image - resembling the training data it was trained on. +Unconditional image generation generates images that look like a random sample from the training data the model was trained on because the denoising process is not guided by any additional context like text or image. -The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference. +To get started, use the [`DiffusionPipeline`] to load the [anton-l/ddpm-butterflies-128](https://huggingface.co/anton-l/ddpm-butterflies-128) checkpoint to generate images of butterflies. The [`DiffusionPipeline`] downloads and caches all the model components required to generate an image. -Start by creating an instance of [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download. -You can use any of the ๐Ÿงจ Diffusers [checkpoints](https://huggingface.co/models?library=diffusers&sort=downloads) from the Hub (the checkpoint you'll use generates images of butterflies). +```py +from diffusers import DiffusionPipeline + +generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128").to("cuda") +image = generator().images[0] +image +``` -๐Ÿ’ก Want to train your own unconditional image generation model? Take a look at the training [guide](../training/unconditional_training) to learn how to generate your own images. +Want to generate images of something else? Take a look at the training [guide](../training/unconditional_training) to learn how to train a model to generate your own images. -In this guide, you'll use [`DiffusionPipeline`] for unconditional image generation with [DDPM](https://arxiv.org/abs/2006.11239): - -```python -from diffusers import DiffusionPipeline - -generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128", use_safetensors=True) -``` +The output image is a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object that can be saved: -The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components. -Because the model consists of roughly 1.4 billion parameters, we strongly recommend running it on a GPU. -You can move the generator object to a GPU, just like you would in PyTorch: - -```python -generator.to("cuda") +```py +image.save("generated_image.png") ``` -Now you can use the `generator` to generate an image: +You can also try experimenting with the `num_inference_steps` parameter, which controls the number of denoising steps. More denoising steps typically produce higher quality images, but it'll take longer to generate. Feel free to play around with this parameter to see how it affects the image quality. -```python -image = generator().images[0] +```py +image = generator(num_inference_steps=100).images[0] image ``` -The output is by default wrapped into a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object. - -You can save the image by calling: - -```python -image.save("generated_image.png") -``` - -Try out the Spaces below, and feel free to play around with the inference steps parameter to see how it affects the image quality! +Try out the Space below to generate an image of a butterfly!