diff --git a/ppdiffusers/examples/cogvideo/README.md b/ppdiffusers/examples/cogvideo/README.md
new file mode 100644
index 000000000..8e13431d6
--- /dev/null
+++ b/ppdiffusers/examples/cogvideo/README.md
@@ -0,0 +1,10 @@
+# CogVideoX Video Generation
+
+```shell
+python infer.py \
+    --prompt "a bear is walking in a zoo" \
+    --model_path THUDM/CogVideoX-2b/ \
+    --generate_type "t2v" \
+    --dtype "float16" \
+    --seed 42
+```
\ No newline at end of file
diff --git a/ppdiffusers/examples/cogvideo/infer.py b/ppdiffusers/examples/cogvideo/infer.py
new file mode 100644
index 000000000..8f5a0bc1c
--- /dev/null
+++ b/ppdiffusers/examples/cogvideo/infer.py
@@ -0,0 +1,186 @@
+"""
+This script demonstrates how to generate a video with the CogVideoX model using the `ppdiffusers` pipelines.
+It supports text-to-video (t2v), image-to-video (i2v), and video-to-video (v2v) generation, depending on the
+input data and the model weights used.
+
+- text-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
+- video-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
+- image-to-video: THUDM/CogVideoX-5b-I2V
+
+Running the Script:
+To run the script, use the following command with appropriate arguments:
+
+```bash
+$ python infer.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-5b --generate_type "t2v"
+```
+
+Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
+"""
+
+import argparse
+from typing import Literal
+
+import paddle
+from ppdiffusers import (
+    CogVideoXPipeline,
+    CogVideoXDDIMScheduler,
+    CogVideoXDPMScheduler,
+    # CogVideoXImageToVideoPipeline,
+    # CogVideoXVideoToVideoPipeline,
+)
+
+from ppdiffusers.utils import export_to_video_2
+
+
+def generate_video(
+    prompt: str,
+    model_path: str,
+    lora_path: str = None,
+    lora_rank: int = 128,
+    output_path: str = "./output.mp4",
+    image_or_video_path: str = "",
+    num_inference_steps: int = 50,
+    guidance_scale: float = 6.0,
+    num_videos_per_prompt: int = 1,
+    dtype: paddle.dtype = paddle.bfloat16,
+    generate_type: Literal["t2v", "i2v", "v2v"] = "t2v",  # i2v: image to video, v2v: video to video
+    seed: int = 42,
+):
+    """
+    Generates a video based on the given prompt and saves it to the specified path.
+
+    Parameters:
+    - prompt (str): The description of the video to be generated.
+    - model_path (str): The path of the pre-trained model to be used.
+    - lora_path (str): The path of the LoRA weights to be used.
+    - lora_rank (int): The rank of the LoRA weights.
+    - output_path (str): The path where the generated video will be saved.
+    - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
+    - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
+    - num_videos_per_prompt (int): Number of videos to generate per prompt.
+    - dtype (paddle.dtype): The data type for computation (default is paddle.bfloat16).
+    - generate_type (str): The type of video generation (one of 't2v', 'i2v', 'v2v').
+    - seed (int): The seed for reproducibility.
+    """
+
+    # 1. Load the pre-trained CogVideoX pipeline with the specified precision.
+    # To run on multiple GPUs, add device_map="balanced" to the from_pretrained call and remove any
+    # enable_model_cpu_offload() call.
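+    # NOTE: only the text-to-video branch is wired up below; the i2v and v2v branches are kept as
+    # commented placeholders that raise NotImplementedError until CogVideoXImageToVideoPipeline and
+    # CogVideoXVideoToVideoPipeline are ported to ppdiffusers.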
+
+    image = None
+    video = None
+
+    if generate_type == "i2v":
+        # pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, dtype=dtype)
+        # image = load_image(image=image_or_video_path)
+        raise NotImplementedError
+    elif generate_type == "t2v":
+        pipe = CogVideoXPipeline.from_pretrained(model_path, paddle_dtype=dtype)
+    else:
+        # pipe = CogVideoXVideoToVideoPipeline.from_pretrained(model_path, dtype=dtype)
+        # video = load_video(image_or_video_path)
+        raise NotImplementedError
+
+    # If you are using LoRA, load and fuse the LoRA weights here.
+    if lora_path:
+        pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
+        pipe.fuse_lora(lora_scale=1 / lora_rank)
+
+    # 2. Set the scheduler.
+    # Can be changed to `CogVideoXDPMScheduler` or `CogVideoXDDIMScheduler`.
+    # We recommend `CogVideoXDDIMScheduler` for CogVideoX-2B
+    # and `CogVideoXDPMScheduler` for CogVideoX-5B / CogVideoX-5B-I2V.
+
+    pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+    # pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+
+    # 3. Enable VAE slicing and tiling to reduce peak memory usage.
+    pipe.vae.enable_slicing()
+    pipe.vae.enable_tiling()
+
+    # 4. Generate the video frames based on the prompt.
+    # `num_frames` is the number of frames to generate.
+    # The default of 49 corresponds to a 6-second video at 8 fps, plus 1 extra frame for the first frame.
+    if generate_type == "i2v":
+        # video_generate = pipe(
+        #     prompt=prompt,
+        #     image=image,  # The input image to condition the video on
+        #     num_videos_per_prompt=num_videos_per_prompt,  # Number of videos to generate per prompt
+        #     num_inference_steps=num_inference_steps,  # Number of inference steps
+        #     num_frames=49,  # Number of frames to generate, changed to 49 for diffusers version `0.30.3` and after.
+        #     use_dynamic_cfg=True,  # This is used for the DPM scheduler; for the DDIM scheduler, it should be False
+        #     guidance_scale=guidance_scale,
+        #     generator=paddle.seed(seed),  # Set the seed for reproducibility
+        # ).frames[0]
+        raise NotImplementedError
+    elif generate_type == "t2v":
+        video_generate = pipe(
+            prompt=prompt,
+            num_videos_per_prompt=num_videos_per_prompt,
+            num_inference_steps=num_inference_steps,
+            num_frames=49,
+            use_dynamic_cfg=True,
+            guidance_scale=guidance_scale,
+            generator=paddle.Generator().manual_seed(seed),
+        ).frames[0]
+    else:
+        # video_generate = pipe(
+        #     prompt=prompt,
+        #     video=video,  # The input video to condition the generation on
+        #     num_videos_per_prompt=num_videos_per_prompt,
+        #     num_inference_steps=num_inference_steps,
+        #     # num_frames=49,
+        #     use_dynamic_cfg=True,
+        #     guidance_scale=guidance_scale,
+        #     generator=paddle.seed(seed),  # Set the seed for reproducibility
+        # ).frames[0]
+        raise NotImplementedError
+    # 5. Export the generated frames to a video file. Use fps=8 to match the model's native frame rate.
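+    # With num_frames=49 and fps=8, the exported clip is (49 - 1) / 8 = 6 seconds long.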
+ export_to_video_2(video_generate, output_path, fps=8) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX") + parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated") + parser.add_argument( + "--image_or_video_path", + type=str, + default=None, + help="The path of the image to be used as the background of the video", + ) + parser.add_argument( + "--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used" + ) + parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used") + parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights") + parser.add_argument( + "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved" + ) + parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance") + parser.add_argument( + "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process" + ) + parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt") + parser.add_argument( + "--generate_type", type=str, default="t2v", help="The type of video generation (e.g., 't2v', 'i2v', 'v2v')" + ) + parser.add_argument( + "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')" + ) + parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility") + + args = parser.parse_args() + dtype = paddle.float16 if args.dtype == "float16" else paddle.bfloat16 + generate_video( + prompt=args.prompt, + model_path=args.model_path, + lora_path=args.lora_path, + lora_rank=args.lora_rank, + output_path=args.output_path, + image_or_video_path=args.image_or_video_path, + num_inference_steps=args.num_inference_steps, + guidance_scale=args.guidance_scale, + num_videos_per_prompt=args.num_videos_per_prompt, + dtype=dtype, + generate_type=args.generate_type, + seed=args.seed, + ) diff --git a/ppdiffusers/ppdiffusers/__init__.py b/ppdiffusers/ppdiffusers/__init__.py index d47fc7a38..5d21f1f1c 100644 --- a/ppdiffusers/ppdiffusers/__init__.py +++ b/ppdiffusers/ppdiffusers/__init__.py @@ -110,8 +110,10 @@ [ "AsymmetricAutoencoderKL", "AutoencoderKL", + "AutoencoderKLCogVideoX", "AutoencoderKLTemporalDecoder", "AutoencoderTiny", + "CogVideoXTransformer3DModel", "ConsistencyDecoderVAE", "ControlNetModel", "Kandinsky3UNet", @@ -182,6 +184,8 @@ _import_structure["schedulers"].extend( [ "CMStochasticIterativeScheduler", + "CogVideoXDDIMScheduler", + "CogVideoXDPMScheduler", "DDIMInverseScheduler", "DDIMParallelScheduler", "DDIMScheduler", @@ -266,6 +270,7 @@ "BlipDiffusionControlNetPipeline", "BlipDiffusionPipeline", "CLIPImageProjection", + "CogVideoXPipeline", "CycleDiffusionPipeline", "IFImg2ImgPipeline", "IFImg2ImgSuperResolutionPipeline", @@ -486,9 +491,11 @@ from .models import ( # new add AsymmetricAutoencoderKL, AutoencoderKL, + AutoencoderKLCogVideoX, AutoencoderKL_imgtovideo, AutoencoderKLTemporalDecoder, AutoencoderTiny, + CogVideoXTransformer3DModel, ConsistencyDecoderVAE, ControlNetModel, DiTLLaMA2DModel, @@ -554,6 +561,8 @@ ) from .schedulers import ( CMStochasticIterativeScheduler, + CogVideoXDDIMScheduler, + CogVideoXDPMScheduler, DDIMInverseScheduler, DDIMParallelScheduler, DDIMScheduler, @@ -619,6 +628,7 @@ 
AudioLDM2UNet2DConditionModel, AudioLDMPipeline, CLIPImageProjection, + CogVideoXPipeline, CycleDiffusionPipeline, IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, diff --git a/ppdiffusers/ppdiffusers/callbacks.py b/ppdiffusers/ppdiffusers/callbacks.py new file mode 100644 index 000000000..38542407e --- /dev/null +++ b/ppdiffusers/ppdiffusers/callbacks.py @@ -0,0 +1,156 @@ +from typing import Any, Dict, List + +from .configuration_utils import ConfigMixin, register_to_config +from .utils import CONFIG_NAME + + +class PipelineCallback(ConfigMixin): + """ + Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing + custom callbacks and ensures that all callbacks have a consistent interface. + + Please implement the following: + `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to + include + variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. + `callback_fn`: This method defines the core functionality of your callback. + """ + + config_name = CONFIG_NAME + + @register_to_config + def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None): + super().__init__() + + if (cutoff_step_ratio is None and cutoff_step_index is None) or ( + cutoff_step_ratio is not None and cutoff_step_index is not None + ): + raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.") + + if cutoff_step_ratio is not None and ( + not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0) + ): + raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.") + + @property + def tensor_inputs(self) -> List[str]: + raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}") + + def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]: + raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}") + + def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + return self.callback_fn(pipeline, step_index, timestep, callback_kwargs) + + +class MultiPipelineCallbacks: + """ + This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and + provides a unified interface for calling all of them. + """ + + def __init__(self, callbacks: List[PipelineCallback]): + self.callbacks = callbacks + + @property + def tensor_inputs(self) -> List[str]: + return [input for callback in self.callbacks for input in callback.tensor_inputs] + + def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + """ + Calls all the callbacks in order with the given arguments and returns the final callback_kwargs. + """ + for callback in self.callbacks: + callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs) + + return callback_kwargs + + +class SDCFGCutoffCallback(PipelineCallback): + """ + Callback function for Stable Diffusion Pipelines. After certain number of steps (set by `cutoff_step_ratio` or + `cutoff_step_index`), this callback will disable the CFG. + + Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. 
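+
+    A minimal usage sketch (this assumes the pipeline's `__call__` accepts `callback_on_step_end`,
+    as the pipelines ported from diffusers do; `pipe` and `prompt` are placeholder names):
+
+    ```py
+    callback = SDCFGCutoffCallback(cutoff_step_ratio=0.4)
+    image = pipe(prompt, callback_on_step_end=callback).images[0]
+    ```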
+ """ + + tensor_inputs = ["prompt_embeds"] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + prompt_embeds = callback_kwargs[self.tensor_inputs[0]] + prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens. + + pipeline._guidance_scale = 0.0 + + callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + return callback_kwargs + + +class SDXLCFGCutoffCallback(PipelineCallback): + """ + Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or + `cutoff_step_index`), this callback will disable the CFG. + + Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step. + """ + + tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + prompt_embeds = callback_kwargs[self.tensor_inputs[0]] + prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens. + + add_text_embeds = callback_kwargs[self.tensor_inputs[1]] + add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens + + add_time_ids = callback_kwargs[self.tensor_inputs[2]] + add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector + + pipeline._guidance_scale = 0.0 + + callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + callback_kwargs[self.tensor_inputs[1]] = add_text_embeds + callback_kwargs[self.tensor_inputs[2]] = add_time_ids + return callback_kwargs + + +class IPAdapterScaleCutoffCallback(PipelineCallback): + """ + Callback function for any pipeline that inherits `IPAdapterMixin`. After certain number of steps (set by + `cutoff_step_ratio` or `cutoff_step_index`), this callback will set the IP Adapter scale to `0.0`. + + Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step. 
+ """ + + tensor_inputs = [] + + def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + pipeline.set_ip_adapter_scale(0.0) + return callback_kwargs diff --git a/ppdiffusers/ppdiffusers/image_processor.py b/ppdiffusers/ppdiffusers/image_processor.py index 75ac521e8..41b8c3daa 100644 --- a/ppdiffusers/ppdiffusers/image_processor.py +++ b/ppdiffusers/ppdiffusers/image_processor.py @@ -650,3 +650,22 @@ def preprocess( depth = self.binarize(depth) return rgb, depth + + +def is_valid_image(image): + return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, paddle.Tensor)) and image.ndim in (2, 3) + + +def is_valid_image_imagelist(images): + # check if the image input is one of the supported formats for image and image list: + # it can be either one of below 3 + # (1) a 4d pytorch tensor or numpy array, + # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor + # (3) a list of valid image + if isinstance(images, (np.ndarray, paddle.Tensor)) and images.ndim == 4: + return True + elif is_valid_image(images): + return True + elif isinstance(images, list): + return all(is_valid_image(image) for image in images) + return False \ No newline at end of file diff --git a/ppdiffusers/ppdiffusers/models/__init__.py b/ppdiffusers/ppdiffusers/models/__init__.py index 81bea4dcb..ed5187022 100644 --- a/ppdiffusers/ppdiffusers/models/__init__.py +++ b/ppdiffusers/ppdiffusers/models/__init__.py @@ -22,6 +22,7 @@ _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoder_kl"] = ["AutoencoderKL"] + _import_structure["autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] _import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] @@ -32,6 +33,7 @@ _import_structure["t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformer_2d"] = ["Transformer2DModel"] _import_structure["transformer_sd3"] = ["SD3Transformer2DModel"] + _import_structure["cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"] _import_structure["transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["unet_1d"] = ["UNet1DModel"] _import_structure["unet_2d"] = ["UNet2DModel"] @@ -64,6 +66,7 @@ from .adapter import MultiAdapter, T2IAdapter from .autoencoder_asym_kl import AsymmetricAutoencoderKL from .autoencoder_kl import AutoencoderKL + from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_tiny import AutoencoderTiny from .consistency_decoder_vae import ConsistencyDecoderVAE @@ -87,6 +90,7 @@ from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel from .transformer_sd3 import SD3Transformer2DModel + from .cogvideox_transformer_3d import CogVideoXTransformer3DModel from .controlnet_sd3 import SD3ControlNetModel from .controlnet_sd3 import SD3MultiControlNetModel 
from .transformer_temporal import TransformerTemporalModel diff --git a/ppdiffusers/ppdiffusers/models/activations.py b/ppdiffusers/ppdiffusers/models/activations.py index 44db4a55d..42d56d455 100644 --- a/ppdiffusers/ppdiffusers/models/activations.py +++ b/ppdiffusers/ppdiffusers/models/activations.py @@ -66,9 +66,9 @@ class GELU(nn.Layer): approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. """ - def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool=True): super().__init__() - self.proj = nn.Linear(dim_in, dim_out) + self.proj = nn.Linear(dim_in, dim_out, bias_attr=bias) self.approximate = approximate def gelu(self, gate: paddle.Tensor) -> paddle.Tensor: @@ -89,11 +89,11 @@ class GEGLU(nn.Layer): dim_out (`int`): The number of channels in the output. """ - def __init__(self, dim_in: int, dim_out: int): + def __init__(self, dim_in: int, dim_out: int, bias: bool=True): super().__init__() linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear - self.proj = linear_cls(dim_in, dim_out * 2) + self.proj = linear_cls(dim_in, dim_out * 2, bias_attr=bias) def gelu(self, gate: paddle.Tensor) -> paddle.Tensor: return F.gelu(gate) @@ -114,9 +114,9 @@ class ApproximateGELU(nn.Layer): dim_out (`int`): The number of channels in the output. """ - def __init__(self, dim_in: int, dim_out: int): + def __init__(self, dim_in: int, dim_out: int, bias: bool=True): super().__init__() - self.proj = nn.Linear(dim_in, dim_out) + self.proj = nn.Linear(dim_in, dim_out, bias_attr=bias) def forward(self, x: paddle.Tensor) -> paddle.Tensor: x = self.proj(x) diff --git a/ppdiffusers/ppdiffusers/models/attention.py b/ppdiffusers/ppdiffusers/models/attention.py index 8b5a9d027..ef3a810d0 100644 --- a/ppdiffusers/ppdiffusers/models/attention.py +++ b/ppdiffusers/ppdiffusers/models/attention.py @@ -641,20 +641,23 @@ def __init__( dropout: float = 0.0, activation_fn: str = "geglu", final_dropout: bool = False, + inner_dim=None, + bias: bool = True, ): super().__init__() - inner_dim = int(dim * mult) + if inner_dim is None: + inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim) + act_fn = GELU(dim, inner_dim, bias=bias) if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh") + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim) + act_fn = GEGLU(dim, inner_dim, bias=bias) elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim) + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) self.net = nn.LayerList([]) # project in @@ -662,7 +665,7 @@ def __init__( # project dropout self.net.append(nn.Dropout(dropout)) # project out - self.net.append(linear_cls(inner_dim, dim_out)) + self.net.append(linear_cls(inner_dim, dim_out, bias_attr=bias)) # FF as used in Vision Transformer, MLP-Mixer, etc. 
have a final dropout if final_dropout: self.net.append(nn.Dropout(dropout)) diff --git a/ppdiffusers/ppdiffusers/models/attention_processor.py b/ppdiffusers/ppdiffusers/models/attention_processor.py index c93c55ae6..3d6f1659b 100644 --- a/ppdiffusers/ppdiffusers/models/attention_processor.py +++ b/ppdiffusers/ppdiffusers/models/attention_processor.py @@ -91,6 +91,7 @@ def __init__( upcast_softmax: bool = False, cross_attention_norm: Optional[str] = None, cross_attention_norm_num_groups: int = 32, + qk_norm: Optional[str] = None, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, @@ -149,6 +150,15 @@ def __init__( else: self.spatial_norm = None + if qk_norm is None: + self.norm_q = None + self.norm_k = None + elif qk_norm == "layer_norm": + self.norm_q = nn.LayerNorm(dim_head, epsilon=eps) + self.norm_k = nn.LayerNorm(dim_head, epsilon=eps) + else: + raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'") + if cross_attention_norm is None: self.norm_cross = None elif cross_attention_norm == "layer_norm": @@ -2091,6 +2101,156 @@ def __call__( return out +class CogVideoXAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on + query and key vectors, but does not include spatial normalization. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: paddle.Tensor, + encoder_hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + image_rotary_emb: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + text_seq_length = encoder_hidden_states.shape[1] + + hidden_states = paddle.concat([encoder_hidden_states, hidden_states], axis=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.reshape([batch_size, attn.heads, -1, attention_mask.shape[-1]]) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3]) + key = key.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3]) + value = value.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3]) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + from .embeddings import apply_rotary_emb + + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) + if not attn.is_cross_attention: + key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) + + # NOTE: There is diff between paddle's and torch's sdpa + # paddle needs input: [batch_size, seq_len, num_heads, head_dim] + # torch needs input: [batch_size, num_heads, seq_len, head_dim] + hidden_states = F.scaled_dot_product_attention( + query.transpose([0, 2, 1, 3]), + key.transpose([0, 2, 1, 3]), + value.transpose([0, 2, 1, 3]), + 
attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False
+        )
+
+        hidden_states = hidden_states.reshape([batch_size, -1, attn.heads * head_dim])
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        encoder_hidden_states, hidden_states = hidden_states.split(
+            [text_seq_length, hidden_states.shape[1] - text_seq_length], axis=1
+        )
+        return hidden_states, encoder_hidden_states
+
+
+class FusedCogVideoXAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+    query and key vectors, but does not include spatial normalization. This variant expects fused QKV projections
+    (a single `attn.to_qkv` linear layer) instead of separate `to_q`, `to_k` and `to_v` projections.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: paddle.Tensor,
+        encoder_hidden_states: paddle.Tensor,
+        attention_mask: Optional[paddle.Tensor] = None,
+        image_rotary_emb: Optional[paddle.Tensor] = None,
+    ) -> paddle.Tensor:
+        text_seq_length = encoder_hidden_states.shape[1]
+
+        hidden_states = paddle.concat([encoder_hidden_states, hidden_states], axis=1)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.reshape([batch_size, attn.heads, -1, attention_mask.shape[-1]])
+
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        # paddle.split takes the number of sections (or a list of section sizes), not the chunk size
+        query, key, value = paddle.split(qkv, [split_size] * 3, axis=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3])
+        key = key.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3])
+        value = value.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3])
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            if not attn.is_cross_attention:
+                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+        # NOTE: There is a diff between paddle's and torch's sdpa
+        # paddle needs input: [batch_size, seq_len, num_heads, head_dim]
+        # torch needs input: [batch_size, num_heads, seq_len, head_dim]
+        hidden_states = F.scaled_dot_product_attention(
+            query.transpose([0, 2, 1, 3]),
+            key.transpose([0, 2, 1, 3]),
+            value.transpose([0, 2, 1, 3]),
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+        )
+
+        hidden_states = hidden_states.reshape([batch_size, -1, attn.heads * head_dim])
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        encoder_hidden_states, hidden_states = hidden_states.split(
+            [text_seq_length, hidden_states.shape[1] - text_seq_length], axis=1
+        )
+        return hidden_states, encoder_hidden_states
+
+
 LoRAAttnProcessor2_5 = LoRAXFormersAttnProcessor
 AttnAddedKVProcessor2_5 = XFormersAttnAddedKVProcessor
 AttnProcessor2_5 = XFormersAttnProcessor
diff --git a/ppdiffusers/ppdiffusers/models/autoencoder_kl_cogvideox.py b/ppdiffusers/ppdiffusers/models/autoencoder_kl_cogvideox.py
new file mode 100644
index 000000000..85a4a3acc
--- /dev/null
+++ b/ppdiffusers/ppdiffusers/models/autoencoder_kl_cogvideox.py
@@ -0,0 +1,1034 @@
+import paddle
+from typing import Optional, Tuple,
Union +import numpy as np +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging +from ..utils.accelerate_utils import apply_forward_hook +from .activations import get_activation +from .downsampling import CogVideoXDownsample3D +from .modeling_outputs import AutoencoderKLOutput +from .modeling_utils import ModelMixin +from .upsampling import CogVideoXUpsample3D +from .vae import DecoderOutput, DiagonalGaussianDistribution +logger = logging.get_logger(__name__) + + +class CogVideoXSafeConv3d(paddle.nn.Conv3D): + """ + A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model. + """ + + def forward(self, input: paddle.Tensor) ->paddle.Tensor: + memory_count = paddle.prod(x=paddle.to_tensor(data=tuple(input.shape)) + ).item() * 2 / 1024 ** 3 + if memory_count > 2: + kernel_size = self.kernel_size[0] + part_num = int(memory_count / 2) + 1 + input_chunks = paddle.chunk(x=input, chunks=part_num, axis=2) + if kernel_size > 1: + input_chunks = [input_chunks[0]] + [paddle.concat(x=( + input_chunks[i - 1][:, :, -kernel_size + 1:], + input_chunks[i]), axis=2) for i in range(1, len( + input_chunks))] + output_chunks = [] + for input_chunk in input_chunks: + output_chunks.append(super().forward(input_chunk)) + output = paddle.concat(x=output_chunks, axis=2) + return output + else: + return super().forward(input) + + +class CogVideoXCausalConv3d(paddle.nn.Layer): + """A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model. + + Args: + in_channels (`int`): Number of channels in the input tensor. + out_channels (`int`): Number of output channels produced by the convolution. + kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel. + stride (`int`, defaults to `1`): Stride of the convolution. + dilation (`int`, defaults to `1`): Dilation rate of the convolution. + pad_mode (`str`, defaults to `"constant"`): Padding mode. 
+ """ + + def __init__(self, in_channels: int, out_channels: int, kernel_size: + Union[int, Tuple[int, int, int]], stride: int=1, dilation: int=1, + pad_mode: str='constant'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + self.pad_mode = pad_mode + time_pad = dilation * (time_kernel_size - 1) + (1 - stride) + height_pad = height_kernel_size // 2 + width_pad = width_kernel_size // 2 + self.height_pad = height_pad + self.width_pad = width_pad + self.time_pad = time_pad + self.time_causal_padding = (width_pad, width_pad, height_pad, + height_pad, time_pad, 0) + self.temporal_dim = 2 + self.time_kernel_size = time_kernel_size + stride = stride, 1, 1 + dilation = dilation, 1, 1 + self.conv = CogVideoXSafeConv3d(in_channels=in_channels, + out_channels=out_channels, kernel_size=kernel_size, stride= + stride, dilation=dilation) + self.conv_cache = None + + def fake_context_parallel_forward(self, inputs: paddle.Tensor + ) ->paddle.Tensor: + kernel_size = self.time_kernel_size + if kernel_size > 1: + cached_inputs = [self.conv_cache + ] if self.conv_cache is not None else [inputs[:, :, :1]] * ( + kernel_size - 1) + inputs = paddle.concat(x=cached_inputs + [inputs], axis=2) + return inputs + + def _clear_fake_context_parallel_cache(self): + del self.conv_cache + self.conv_cache = None + + def forward(self, inputs: paddle.Tensor) ->paddle.Tensor: + inputs = self.fake_context_parallel_forward(inputs) + self._clear_fake_context_parallel_cache() + self.conv_cache = inputs[:, :, -self.time_kernel_size + 1:].clone() + padding_2d = (self.width_pad, self.width_pad, self.height_pad, self + .height_pad) + inputs = paddle.nn.functional.pad(x=inputs, pad=padding_2d, mode= + 'constant', value=0, pad_from_left_axis=False) + output = self.conv(inputs) + return output + + +class CogVideoXSpatialNorm3D(paddle.nn.Layer): + """ + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific + to 3D-video like data. + + CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model. + + Args: + f_channels (`int`): + The number of channels for input to group normalization layer, and output of the spatial norm layer. + zq_channels (`int`): + The number of channels for the quantized vector as described in the paper. + groups (`int`): + Number of groups to separate the channels into for group normalization. 
+ """ + + def __init__(self, f_channels: int, zq_channels: int, groups: int=32): + super().__init__() + self.norm_layer = paddle.nn.GroupNorm(num_channels=f_channels, + num_groups=groups, epsilon=1e-06, weight_attr=True, bias_attr=True) + self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, + kernel_size=1, stride=1) + self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, + kernel_size=1, stride=1) + + def forward(self, f: paddle.Tensor, zq: paddle.Tensor) ->paddle.Tensor: + if tuple(f.shape)[2] > 1 and tuple(f.shape)[2] % 2 == 1: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + f_first_size, f_rest_size = tuple(f_first.shape)[-3:], tuple(f_rest + .shape)[-3:] + z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] + z_first = paddle.nn.functional.interpolate(x=z_first, size= + f_first_size) + z_rest = paddle.nn.functional.interpolate(x=z_rest, size= + f_rest_size) + zq = paddle.concat(x=[z_first, z_rest], axis=2) + else: + zq = paddle.nn.functional.interpolate(x=zq, size=tuple(f.shape) + [-3:]) + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + + +class CogVideoXResnetBlock3D(paddle.nn.Layer): + """ + A 3D ResNet block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. + spatial_norm_dim (`int`, *optional*): + The dimension to use for spatial norm if it is to be used instead of group norm. + pad_mode (str, defaults to `"first"`): + Padding mode. 
+ """ + + def __init__(self, in_channels: int, out_channels: Optional[int]=None, + dropout: float=0.0, temb_channels: int=512, groups: int=32, eps: + float=1e-06, non_linearity: str='swish', conv_shortcut: bool=False, + spatial_norm_dim: Optional[int]=None, pad_mode: str='first'): + super().__init__() + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.nonlinearity = get_activation(non_linearity) + self.use_conv_shortcut = conv_shortcut + if spatial_norm_dim is None: + self.norm1 = paddle.nn.GroupNorm(num_channels=in_channels, + num_groups=groups, epsilon=eps) + self.norm2 = paddle.nn.GroupNorm(num_channels=out_channels, + num_groups=groups, epsilon=eps) + else: + self.norm1 = CogVideoXSpatialNorm3D(f_channels=in_channels, + zq_channels=spatial_norm_dim, groups=groups) + self.norm2 = CogVideoXSpatialNorm3D(f_channels=out_channels, + zq_channels=spatial_norm_dim, groups=groups) + self.conv1 = CogVideoXCausalConv3d(in_channels=in_channels, + out_channels=out_channels, kernel_size=3, pad_mode=pad_mode) + if temb_channels > 0: + self.temb_proj = paddle.nn.Linear(in_features=temb_channels, + out_features=out_channels) + self.dropout = paddle.nn.Dropout(p=dropout) + self.conv2 = CogVideoXCausalConv3d(in_channels=out_channels, + out_channels=out_channels, kernel_size=3, pad_mode=pad_mode) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = CogVideoXCausalConv3d(in_channels= + in_channels, out_channels=out_channels, kernel_size=3, + pad_mode=pad_mode) + else: + self.conv_shortcut = CogVideoXSafeConv3d(in_channels= + in_channels, out_channels=out_channels, kernel_size=1, + stride=1, padding=0) + + def forward(self, inputs: paddle.Tensor, temb: Optional[paddle.Tensor]= + None, zq: Optional[paddle.Tensor]=None) ->paddle.Tensor: + hidden_states = inputs + if zq is not None: + hidden_states = self.norm1(hidden_states, zq) + else: + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + if temb is not None: + hidden_states = hidden_states + self.temb_proj(self. + nonlinearity(temb))[:, :, None, None, None] + if zq is not None: + hidden_states = self.norm2(hidden_states, zq) + else: + hidden_states = self.norm2(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + if self.in_channels != self.out_channels: + inputs = self.conv_shortcut(inputs) + hidden_states = hidden_states + inputs + return hidden_states + + +class CogVideoXDownBlock3D(paddle.nn.Layer): + """ + A downsampling block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + add_downsample (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. 
+ compress_time (`bool`, defaults to `False`): + Whether or not to downsample across temporal dimension. + pad_mode (str, defaults to `"first"`): + Padding mode. + """ + _supports_gradient_checkpointing = True + + def __init__(self, in_channels: int, out_channels: int, temb_channels: + int, dropout: float=0.0, num_layers: int=1, resnet_eps: float=1e-06, + resnet_act_fn: str='swish', resnet_groups: int=32, add_downsample: + bool=True, downsample_padding: int=0, compress_time: bool=False, + pad_mode: str='first'): + super().__init__() + resnets = [] + for i in range(num_layers): + in_channel = in_channels if i == 0 else out_channels + resnets.append(CogVideoXResnetBlock3D(in_channels=in_channel, + out_channels=out_channels, dropout=dropout, temb_channels= + temb_channels, groups=resnet_groups, eps=resnet_eps, + non_linearity=resnet_act_fn, pad_mode=pad_mode)) + self.resnets = paddle.nn.LayerList(sublayers=resnets) + self.downsamplers = None + if add_downsample: + self.downsamplers = paddle.nn.LayerList(sublayers=[ + CogVideoXDownsample3D(out_channels, out_channels, padding= + downsample_padding, compress_time=compress_time)]) + self.gradient_checkpointing = False + + def forward(self, hidden_states: paddle.Tensor, temb: Optional[paddle. + Tensor]=None, zq: Optional[paddle.Tensor]=None) ->paddle.Tensor: + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + + def create_forward(*inputs): + return module(*inputs) + return create_forward + hidden_states = paddle.distributed.fleet.utils.recompute( + create_custom_forward(resnet), hidden_states, temb, zq) + else: + hidden_states = resnet(hidden_states, temb, zq) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + return hidden_states + + +class CogVideoXMidBlock3D(paddle.nn.Layer): + """ + A middle block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + dropout (`float`, defaults to `0.0`): + Dropout rate. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + spatial_norm_dim (`int`, *optional*): + The dimension to use for spatial norm if it is to be used instead of group norm. + pad_mode (str, defaults to `"first"`): + Padding mode. + """ + _supports_gradient_checkpointing = True + + def __init__(self, in_channels: int, temb_channels: int, dropout: float + =0.0, num_layers: int=1, resnet_eps: float=1e-06, resnet_act_fn: + str='swish', resnet_groups: int=32, spatial_norm_dim: Optional[int] + =None, pad_mode: str='first'): + super().__init__() + resnets = [] + for _ in range(num_layers): + resnets.append(CogVideoXResnetBlock3D(in_channels=in_channels, + out_channels=in_channels, dropout=dropout, temb_channels= + temb_channels, groups=resnet_groups, eps=resnet_eps, + spatial_norm_dim=spatial_norm_dim, non_linearity= + resnet_act_fn, pad_mode=pad_mode)) + self.resnets = paddle.nn.LayerList(sublayers=resnets) + self.gradient_checkpointing = False + + def forward(self, hidden_states: paddle.Tensor, temb: Optional[paddle. 
+ Tensor]=None, zq: Optional[paddle.Tensor]=None) ->paddle.Tensor: + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + + def create_forward(*inputs): + return module(*inputs) + return create_forward + hidden_states = paddle.distributed.fleet.utils.recompute( + create_custom_forward(resnet), hidden_states, temb, zq) + else: + hidden_states = resnet(hidden_states, temb, zq) + return hidden_states + + +class CogVideoXUpBlock3D(paddle.nn.Layer): + """ + An upsampling block used in the CogVideoX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + temb_channels (`int`, defaults to `512`): + Number of time embedding channels. + dropout (`float`, defaults to `0.0`): + Dropout rate. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + resnet_groups (`int`, defaults to `32`): + Number of groups to separate the channels into for group normalization. + spatial_norm_dim (`int`, defaults to `16`): + The dimension to use for spatial norm if it is to be used instead of group norm. + add_upsample (`bool`, defaults to `True`): + Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension. + compress_time (`bool`, defaults to `False`): + Whether or not to downsample across temporal dimension. + pad_mode (str, defaults to `"first"`): + Padding mode. + """ + + def __init__(self, in_channels: int, out_channels: int, temb_channels: + int, dropout: float=0.0, num_layers: int=1, resnet_eps: float=1e-06, + resnet_act_fn: str='swish', resnet_groups: int=32, spatial_norm_dim: + int=16, add_upsample: bool=True, upsample_padding: int=1, + compress_time: bool=False, pad_mode: str='first'): + super().__init__() + resnets = [] + for i in range(num_layers): + in_channel = in_channels if i == 0 else out_channels + resnets.append(CogVideoXResnetBlock3D(in_channels=in_channel, + out_channels=out_channels, dropout=dropout, temb_channels= + temb_channels, groups=resnet_groups, eps=resnet_eps, + non_linearity=resnet_act_fn, spatial_norm_dim= + spatial_norm_dim, pad_mode=pad_mode)) + self.resnets = paddle.nn.LayerList(sublayers=resnets) + self.upsamplers = None + if add_upsample: + self.upsamplers = paddle.nn.LayerList(sublayers=[ + CogVideoXUpsample3D(out_channels, out_channels, padding= + upsample_padding, compress_time=compress_time)]) + self.gradient_checkpointing = False + + def forward(self, hidden_states: paddle.Tensor, temb: Optional[paddle. 
+ Tensor]=None, zq: Optional[paddle.Tensor]=None) ->paddle.Tensor: + """Forward method of the `CogVideoXUpBlock3D` class.""" + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + + def create_forward(*inputs): + return module(*inputs) + return create_forward + hidden_states = paddle.distributed.fleet.utils.recompute( + create_custom_forward(resnet), hidden_states, temb, zq) + else: + hidden_states = resnet(hidden_states, temb, zq) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + return hidden_states + + +class CogVideoXEncoder3D(paddle.nn.Layer): + """ + The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + """ + _supports_gradient_checkpointing = True + + def __init__(self, in_channels: int=3, out_channels: int=16, + down_block_types: Tuple[str, ...]=('CogVideoXDownBlock3D', + 'CogVideoXDownBlock3D', 'CogVideoXDownBlock3D', + 'CogVideoXDownBlock3D'), block_out_channels: Tuple[int, ...]=(128, + 256, 256, 512), layers_per_block: int=3, act_fn: str='silu', + norm_eps: float=1e-06, norm_num_groups: int=32, dropout: float=0.0, + pad_mode: str='first', temporal_compression_ratio: float=4): + super().__init__() + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + self.conv_in = CogVideoXCausalConv3d(in_channels, + block_out_channels[0], kernel_size=3, pad_mode=pad_mode) + self.down_blocks = paddle.nn.LayerList(sublayers=[]) + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + if down_block_type == 'CogVideoXDownBlock3D': + down_block = CogVideoXDownBlock3D(in_channels=input_channel, + out_channels=output_channel, temb_channels=0, dropout= + dropout, num_layers=layers_per_block, resnet_eps= + norm_eps, resnet_act_fn=act_fn, resnet_groups= + norm_num_groups, add_downsample=not is_final_block, + compress_time=compress_time) + else: + raise ValueError( + 'Invalid `down_block_type` encountered. 
Must be `CogVideoXDownBlock3D`' + ) + self.down_blocks.append(down_block) + self.mid_block = CogVideoXMidBlock3D(in_channels=block_out_channels + [-1], temb_channels=0, dropout=dropout, num_layers=2, + resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups= + norm_num_groups, pad_mode=pad_mode) + self.norm_out = paddle.nn.GroupNorm(num_groups=norm_num_groups, + num_channels=block_out_channels[-1], epsilon=1e-06) + self.conv_act = paddle.nn.Silu() + self.conv_out = CogVideoXCausalConv3d(block_out_channels[-1], 2 * + out_channels, kernel_size=3, pad_mode=pad_mode) + self.gradient_checkpointing = False + + def forward(self, sample: paddle.Tensor, temb: Optional[paddle.Tensor]=None + ) ->paddle.Tensor: + """The forward method of the `CogVideoXEncoder3D` class.""" + hidden_states = self.conv_in(sample) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + for down_block in self.down_blocks: + hidden_states = paddle.distributed.fleet.utils.recompute( + create_custom_forward(down_block), hidden_states, temb, + None) + hidden_states = paddle.distributed.fleet.utils.recompute( + create_custom_forward(self.mid_block), hidden_states, temb, + None) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states, temb, None) + hidden_states = self.mid_block(hidden_states, temb, None) + hidden_states = self.norm_out(hidden_states) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + return hidden_states + + +class CogVideoXDecoder3D(paddle.nn.Layer): + """ + The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output + sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. 
+ """ + _supports_gradient_checkpointing = True + + def __init__(self, in_channels: int=16, out_channels: int=3, + up_block_types: Tuple[str, ...]=('CogVideoXUpBlock3D', + 'CogVideoXUpBlock3D', 'CogVideoXUpBlock3D', 'CogVideoXUpBlock3D'), + block_out_channels: Tuple[int, ...]=(128, 256, 256, 512), + layers_per_block: int=3, act_fn: str='silu', norm_eps: float=1e-06, + norm_num_groups: int=32, dropout: float=0.0, pad_mode: str='first', + temporal_compression_ratio: float=4): + super().__init__() + reversed_block_out_channels = list(reversed(block_out_channels)) + self.conv_in = CogVideoXCausalConv3d(in_channels, + reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode) + self.mid_block = CogVideoXMidBlock3D(in_channels= + reversed_block_out_channels[0], temb_channels=0, num_layers=2, + resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups= + norm_num_groups, spatial_norm_dim=in_channels, pad_mode=pad_mode) + self.up_blocks = paddle.nn.LayerList(sublayers=[]) + output_channel = reversed_block_out_channels[0] + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + if up_block_type == 'CogVideoXUpBlock3D': + up_block = CogVideoXUpBlock3D(in_channels= + prev_output_channel, out_channels=output_channel, + temb_channels=0, dropout=dropout, num_layers= + layers_per_block + 1, resnet_eps=norm_eps, + resnet_act_fn=act_fn, resnet_groups=norm_num_groups, + spatial_norm_dim=in_channels, add_upsample=not + is_final_block, compress_time=compress_time, pad_mode= + pad_mode) + prev_output_channel = output_channel + else: + raise ValueError( + 'Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`' + ) + self.up_blocks.append(up_block) + self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[ + -1], in_channels, groups=norm_num_groups) + self.conv_act = paddle.nn.Silu() + self.conv_out = CogVideoXCausalConv3d(reversed_block_out_channels[- + 1], out_channels, kernel_size=3, pad_mode=pad_mode) + self.gradient_checkpointing = False + + def forward(self, sample: paddle.Tensor, temb: Optional[paddle.Tensor]=None + ) ->paddle.Tensor: + """The forward method of the `CogVideoXDecoder3D` class.""" + hidden_states = self.conv_in(sample) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + hidden_states = paddle.distributed.fleet.utils.recompute( + create_custom_forward(self.mid_block), hidden_states, temb, + sample) + for up_block in self.up_blocks: + hidden_states = paddle.distributed.fleet.utils.recompute( + create_custom_forward(up_block), hidden_states, temb, + sample) + else: + hidden_states = self.mid_block(hidden_states, temb, sample) + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states, temb, sample) + hidden_states = self.norm_out(hidden_states, sample) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + return hidden_states + + +class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin): + """ + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [CogVideoX](https://github.com/THUDM/CogVideo). + + This model inherits from [`ModelMixin`]. 
Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to `1.15258426`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + """ + _supports_gradient_checkpointing = True + _no_split_modules = ['CogVideoXResnetBlock3D'] + + @register_to_config + def __init__(self, in_channels: int=3, out_channels: int=3, + down_block_types: Tuple[str]=('CogVideoXDownBlock3D', + 'CogVideoXDownBlock3D', 'CogVideoXDownBlock3D', + 'CogVideoXDownBlock3D'), up_block_types: Tuple[str]=( + 'CogVideoXUpBlock3D', 'CogVideoXUpBlock3D', 'CogVideoXUpBlock3D', + 'CogVideoXUpBlock3D'), block_out_channels: Tuple[int]=(128, 256, + 256, 512), latent_channels: int=16, layers_per_block: int=3, act_fn: + str='silu', norm_eps: float=1e-06, norm_num_groups: int=32, + temporal_compression_ratio: float=4, sample_height: int=480, + sample_width: int=720, scaling_factor: float=1.15258426, + shift_factor: Optional[float]=None, latents_mean: Optional[Tuple[ + float]]=None, latents_std: Optional[Tuple[float]]=None, + force_upcast: float=True, use_quant_conv: bool=False, + use_post_quant_conv: bool=False): + super().__init__() + self.encoder = CogVideoXEncoder3D(in_channels=in_channels, + out_channels=latent_channels, down_block_types=down_block_types, + block_out_channels=block_out_channels, layers_per_block= + layers_per_block, act_fn=act_fn, norm_eps=norm_eps, + norm_num_groups=norm_num_groups, temporal_compression_ratio= + temporal_compression_ratio) + self.decoder = CogVideoXDecoder3D(in_channels=latent_channels, + out_channels=out_channels, up_block_types=up_block_types, + block_out_channels=block_out_channels, layers_per_block= + layers_per_block, act_fn=act_fn, norm_eps=norm_eps, + norm_num_groups=norm_num_groups, temporal_compression_ratio= + temporal_compression_ratio) + self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * + 
out_channels, 1) if use_quant_conv else None + self.post_quant_conv = CogVideoXSafeConv3d(out_channels, + out_channels, 1) if use_post_quant_conv else None + self.use_slicing = False + self.use_tiling = False + self.num_latent_frames_batch_size = 2 + self.num_sample_frames_batch_size = 8 + self.tile_sample_min_height = sample_height // 2 + self.tile_sample_min_width = sample_width // 2 + self.tile_latent_min_height = int(self.tile_sample_min_height / 2 ** + (len(self.config.block_out_channels) - 1)) + self.tile_latent_min_width = int(self.tile_sample_min_width / 2 ** + (len(self.config.block_out_channels) - 1)) + self.tile_overlap_factor_height = 1 / 6 + self.tile_overlap_factor_width = 1 / 5 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)): + module.gradient_checkpointing = value + + def _clear_fake_context_parallel_cache(self): + for name, module in self.named_sublayers(): + if isinstance(module, CogVideoXCausalConv3d): + logger.debug( + f'Clearing fake Context Parallel cache for layer: {name}') + module._clear_fake_context_parallel_cache() + + def enable_tiling(self, tile_sample_min_height: Optional[int]=None, + tile_sample_min_width: Optional[int]=None, + tile_overlap_factor_height: Optional[float]=None, + tile_overlap_factor_width: Optional[float]=None) ->None: + """ + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_overlap_factor_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + tile_overlap_factor_width (`int`, *optional*): + The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there + are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher + value might cause more tiles to be processed leading to slow down of the decoding process. + """ + self.use_tiling = True + self.tile_sample_min_height = (tile_sample_min_height or self. + tile_sample_min_height) + self.tile_sample_min_width = (tile_sample_min_width or self. + tile_sample_min_width) + self.tile_latent_min_height = int(self.tile_sample_min_height / 2 ** + (len(self.config.block_out_channels) - 1)) + self.tile_latent_min_width = int(self.tile_sample_min_width / 2 ** + (len(self.config.block_out_channels) - 1)) + self.tile_overlap_factor_height = (tile_overlap_factor_height or + self.tile_overlap_factor_height) + self.tile_overlap_factor_width = (tile_overlap_factor_width or self + .tile_overlap_factor_width) + + def disable_tiling(self) ->None: + """ + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) ->None: + """ + Enable sliced VAE decoding. 
When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) ->None: + """ + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _encode(self, x: paddle.Tensor) ->paddle.Tensor: + batch_size, num_channels, num_frames, height, width = tuple(x.shape) + if self.use_tiling and (width > self.tile_sample_min_width or + height > self.tile_sample_min_height): + return self.tiled_encode(x) + frame_batch_size = self.num_sample_frames_batch_size + num_batches = num_frames // frame_batch_size if num_frames > 1 else 1 + enc = [] + for i in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * i + (0 if i == 0 else + remaining_frames) + end_frame = frame_batch_size * (i + 1) + remaining_frames + x_intermediate = x[:, :, start_frame:end_frame] + x_intermediate = self.encoder(x_intermediate) + if self.quant_conv is not None: + x_intermediate = self.quant_conv(x_intermediate) + enc.append(x_intermediate) + self._clear_fake_context_parallel_cache() + enc = paddle.concat(x=enc, axis=2) + return enc + + @apply_forward_hook + def encode(self, x: paddle.Tensor, return_dict: bool=True) ->Union[ + AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and tuple(x.shape)[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = paddle.concat(x=encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + if not return_dict: + return posterior, + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: paddle.Tensor, return_dict: bool=True) ->Union[ + DecoderOutput, paddle.Tensor]: + batch_size, num_channels, num_frames, height, width = tuple(z.shape) + if self.use_tiling and (width > self.tile_latent_min_width or + height > self.tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + frame_batch_size = self.num_latent_frames_batch_size + num_batches = num_frames // frame_batch_size + dec = [] + for i in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * i + (0 if i == 0 else + remaining_frames) + end_frame = frame_batch_size * (i + 1) + remaining_frames + z_intermediate = z[:, :, start_frame:end_frame] + if self.post_quant_conv is not None: + z_intermediate = self.post_quant_conv(z_intermediate) + z_intermediate = self.decoder(z_intermediate) + dec.append(z_intermediate) + self._clear_fake_context_parallel_cache() + dec = paddle.concat(x=dec, axis=2) + if not return_dict: + return dec, + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: paddle.Tensor, return_dict: bool=True) ->Union[ + DecoderOutput, paddle.Tensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and tuple(z.shape)[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z + .split(1)] + decoded = paddle.concat(x=decoded_slices) + else: + decoded = self._decode(z).sample + if not return_dict: + return decoded, + return DecoderOutput(sample=decoded) + + def blend_v(self, a: paddle.Tensor, b: paddle.Tensor, blend_extent: int + ) ->paddle.Tensor: + blend_extent = min(tuple(a.shape)[3], tuple(b.shape)[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / + blend_extent) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: paddle.Tensor, b: paddle.Tensor, blend_extent: int + ) ->paddle.Tensor: + blend_extent = min(tuple(a.shape)[4], tuple(b.shape)[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / + blend_extent) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x: paddle.Tensor) ->paddle.Tensor: + """Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + batch_size, num_channels, num_frames, height, width = tuple(x.shape) + overlap_height = int(self.tile_sample_min_height * (1 - self. + tile_overlap_factor_height)) + overlap_width = int(self.tile_sample_min_width * (1 - self. + tile_overlap_factor_width)) + blend_extent_height = int(self.tile_latent_min_height * self. + tile_overlap_factor_height) + blend_extent_width = int(self.tile_latent_min_width * self. + tile_overlap_factor_width) + row_limit_height = self.tile_latent_min_height - blend_extent_height + row_limit_width = self.tile_latent_min_width - blend_extent_width + frame_batch_size = self.num_sample_frames_batch_size + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + num_batches = (num_frames // frame_batch_size if num_frames > + 1 else 1) + time = [] + for k in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else + remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = x[:, :, start_frame:end_frame, i:i + self. + tile_sample_min_height, j:j + self. 
+ tile_sample_min_width] + tile = self.encoder(tile) + if self.quant_conv is not None: + tile = self.quant_conv(tile) + time.append(tile) + self._clear_fake_context_parallel_cache() + row.append(paddle.concat(x=time, axis=2)) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, + blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, : + row_limit_width]) + result_rows.append(paddle.concat(x=result_row, axis=4)) + enc = paddle.concat(x=result_rows, axis=3) + return enc + + def tiled_decode(self, z: paddle.Tensor, return_dict: bool=True) ->Union[ + DecoderOutput, paddle.Tensor]: + """ + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + batch_size, num_channels, num_frames, height, width = tuple(z.shape) + overlap_height = int(self.tile_latent_min_height * (1 - self. + tile_overlap_factor_height)) + overlap_width = int(self.tile_latent_min_width * (1 - self. + tile_overlap_factor_width)) + blend_extent_height = int(self.tile_sample_min_height * self. + tile_overlap_factor_height) + blend_extent_width = int(self.tile_sample_min_width * self. + tile_overlap_factor_width) + row_limit_height = self.tile_sample_min_height - blend_extent_height + row_limit_width = self.tile_sample_min_width - blend_extent_width + frame_batch_size = self.num_latent_frames_batch_size + rows = [] + for i in range(0, height, overlap_height): + row = [] + for j in range(0, width, overlap_width): + num_batches = num_frames // frame_batch_size + time = [] + for k in range(num_batches): + remaining_frames = num_frames % frame_batch_size + start_frame = frame_batch_size * k + (0 if k == 0 else + remaining_frames) + end_frame = frame_batch_size * (k + 1) + remaining_frames + tile = z[:, :, start_frame:end_frame, i:i + self. + tile_latent_min_height, j:j + self. 
+ tile_latent_min_width] + if self.post_quant_conv is not None: + tile = self.post_quant_conv(tile) + tile = self.decoder(tile) + time.append(tile) + self._clear_fake_context_parallel_cache() + row.append(paddle.concat(x=time, axis=2)) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, + blend_extent_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent_width) + result_row.append(tile[:, :, :, :row_limit_height, : + row_limit_width]) + result_rows.append(paddle.concat(x=result_row, axis=4)) + dec = paddle.concat(x=result_rows, axis=3) + if not return_dict: + return dec, + return DecoderOutput(sample=dec) + + def forward(self, sample: paddle.Tensor, sample_posterior: bool=False, + return_dict: bool=True, generator: Optional[paddle.seed]=None + ) ->Union[paddle.Tensor, paddle.Tensor]: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + if not return_dict: + return dec, + return dec diff --git a/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d.py b/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d.py new file mode 100644 index 000000000..b6db5c965 --- /dev/null +++ b/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d.py @@ -0,0 +1,394 @@ +import paddle +from typing import Any, Dict, Optional, Tuple, Union +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging +from ..utils.paddle_utils import maybe_allow_in_graph +from .attention import Attention, FeedForward +from .attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0 +from .embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps +from .modeling_outputs import Transformer2DModelOutput +from .modeling_utils import ModelMixin +from .normalization import AdaLayerNorm, CogVideoXLayerNormZero +logger = logging.get_logger(__name__) + + +@maybe_allow_in_graph +class CogVideoXBlock(paddle.nn.Layer): + """ + Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model. + + Parameters: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + time_embed_dim (`int`): + The number of channels in timestep embedding. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to be used in feed-forward. + attention_bias (`bool`, defaults to `False`): + Whether or not to use bias in attention projection layers. + qk_norm (`bool`, defaults to `True`): + Whether or not to use normalization after query and key projections in Attention. + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, defaults to `1e-5`): + Epsilon value for normalization layers. + final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*, defaults to `None`): + Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. + ff_bias (`bool`, defaults to `True`): + Whether or not to use bias in Feed-forward layer. 
+ attention_out_bias (`bool`, defaults to `True`): + Whether or not to use bias in Attention output projection layer. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + time_embed_dim: int, + dropout: float=0.0, + activation_fn: str='gelu-approximate', + attention_bias: bool=False, + qk_norm: bool=True, + norm_elementwise_affine: bool=True, + norm_eps: float=1e-05, + final_dropout: bool=True, + ff_inner_dim: Optional[int]=None, + ff_bias: bool=True, + attention_out_bias: bool=True + ): + super().__init__() + + # 1. self attention + self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + + self.attn1 = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm='layer_norm' if qk_norm else None, + eps=1e-06, + bias=attention_bias, + out_bias=attention_out_bias, + processor=CogVideoXAttnProcessor2_0() + ) + + # 2. feed forward + self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias + ) + + def forward( + self, + hidden_states: paddle.Tensor, + encoder_hidden_states:paddle.Tensor, + temb: paddle.Tensor, + image_rotary_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]]=None + ) ->paddle.Tensor: + text_seq_length = encoder_hidden_states.shape[1] + + # norm and modulate + norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( + hidden_states, encoder_hidden_states, temb + ) + + # attention + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb + ) + + hidden_states = hidden_states + gate_msa * attn_hidden_states + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states + + # norm and modulate + norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( + hidden_states, encoder_hidden_states, temb + ) + + # feed forward + norm_hidden_states = paddle.concat([norm_encoder_hidden_states, norm_hidden_states], axis=1) + ff_output = self.ff(norm_hidden_states) + + hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states = (encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]) + + return hidden_states, encoder_hidden_states + + +class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin): + """ + A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo). + + Parameters: + num_attention_heads (`int`, defaults to `30`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + in_channels (`int`, defaults to `16`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `16`): + The number of channels in the output. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + time_embed_dim (`int`, defaults to `512`): + Output dimension of timestep embeddings. + text_embed_dim (`int`, defaults to `4096`): + Input dimension of text embeddings from the text encoder. + num_layers (`int`, defaults to `30`): + The number of layers of Transformer blocks to use. 
+        dropout (`float`, defaults to `0.0`):
+            The dropout probability to use.
+        attention_bias (`bool`, defaults to `True`):
+            Whether or not to use bias in the attention projection layers.
+        sample_width (`int`, defaults to `90`):
+            The width of the input latents.
+        sample_height (`int`, defaults to `60`):
+            The height of the input latents.
+        sample_frames (`int`, defaults to `49`):
+            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+            instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
+            but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
+            K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
+        patch_size (`int`, defaults to `2`):
+            The size of the patches to use in the patch embedding layer.
+        temporal_compression_ratio (`int`, defaults to `4`):
+            The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+        max_text_seq_length (`int`, defaults to `226`):
+            The maximum sequence length of the input text embeddings.
+        activation_fn (`str`, defaults to `"gelu-approximate"`):
+            Activation function to use in feed-forward.
+        timestep_activation_fn (`str`, defaults to `"silu"`):
+            Activation function to use when generating the timestep embeddings.
+        norm_elementwise_affine (`bool`, defaults to `True`):
+            Whether or not to use elementwise affine in normalization layers.
+        norm_eps (`float`, defaults to `1e-5`):
+            The epsilon value to use in normalization layers.
+        spatial_interpolation_scale (`float`, defaults to `1.875`):
+            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+        temporal_interpolation_scale (`float`, defaults to `1.0`):
+            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
+    """
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        num_attention_heads: int=30,
+        attention_head_dim: int=64,
+        in_channels: int=16,
+        out_channels: Optional[int]=16,
+        flip_sin_to_cos: bool=True,
+        freq_shift: int=0,
+        time_embed_dim: int=512,
+        text_embed_dim: int=4096,
+        num_layers: int=30,
+        dropout: float=0.0,
+        attention_bias: bool=True,
+        sample_width: int=90,
+        sample_height: int=60,
+        sample_frames: int=49,
+        patch_size: int=2,
+        temporal_compression_ratio: int=4,
+        max_text_seq_length: int=226,
+        activation_fn: str='gelu-approximate',
+        timestep_activation_fn: str='silu',
+        norm_elementwise_affine: bool=True,
+        norm_eps: float=1e-05,
+        spatial_interpolation_scale: float=1.875,
+        temporal_interpolation_scale: float=1.0,
+        use_rotary_positional_embeddings: bool=False,
+        use_learned_positional_embeddings: bool=False
+    ):
+        super().__init__()
+        inner_dim = num_attention_heads * attention_head_dim
+        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+            raise ValueError(
+                "There are no CogVideoX checkpoints available with rotary embeddings disabled and learned positional embeddings enabled. If you're using a custom model and/or believe this should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
+ ) + self.patch_embed = CogVideoXPatchEmbed(patch_size=patch_size, + in_channels=in_channels, embed_dim=inner_dim, text_embed_dim= + text_embed_dim, bias=True, sample_width=sample_width, + sample_height=sample_height, sample_frames=sample_frames, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_positional_embeddings=not use_rotary_positional_embeddings, + use_learned_positional_embeddings=use_learned_positional_embeddings + ) + self.embedding_dropout = paddle.nn.Dropout(p=dropout) + self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, + timestep_activation_fn) + self.transformer_blocks = paddle.nn.LayerList(sublayers=[ + CogVideoXBlock(dim=inner_dim, num_attention_heads= + num_attention_heads, attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, dropout=dropout, activation_fn= + activation_fn, attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, norm_eps= + norm_eps) for _ in range(num_layers)]) + self.norm_final = paddle.nn.LayerNorm(normalized_shape=inner_dim, + epsilon=norm_eps, weight_attr=norm_elementwise_affine, + bias_attr=norm_elementwise_affine) + self.norm_out = AdaLayerNorm(embedding_dim=time_embed_dim, + output_dim=2 * inner_dim, norm_elementwise_affine= + norm_elementwise_affine, norm_eps=norm_eps, chunk_dim=1) + self.proj_out = paddle.nn.Linear(in_features=inner_dim, + out_features=patch_size * patch_size * out_channels) + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + @property + def attn_processors(self) ->Dict[str, AttentionProcessor]: + """ + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + processors = {} + + def fn_recursive_add_processors(name: str, module: paddle.nn.Layer, + processors: Dict[str, AttentionProcessor]): + if hasattr(module, 'get_processor'): + processors[f'{name}.processor'] = module.get_processor() + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f'{name}.{sub_name}', child, + processors) + return processors + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[ + str, AttentionProcessor]]): + """ + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f'A dict of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}. Please make sure to pass {count} processor classes.' 
+ ) + + def fn_recursive_attn_processor(name: str, module: paddle.nn.Layer, + processor): + if hasattr(module, 'set_processor'): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f'{name}.processor')) + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f'{name}.{sub_name}', child, + processor) + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + for _, attn_processor in self.attn_processors.items(): + if 'Added' in str(attn_processor.__class__.__name__): + raise ValueError( + '`fuse_qkv_projections()` is not supported for models having added KV projections.' + ) + self.original_attn_processors = self.attn_processors + for module in self.sublayers(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + self.set_attn_processor(FusedCogVideoXAttnProcessor2_0()) + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def forward( + self, + hidden_states: paddle.Tensor, + encoder_hidden_states: paddle.Tensor, + timestep: Union[int, float, paddle.Tensor], + timestep_cond: Optional[paddle.Tensor]=None, + image_rotary_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]]=None, + return_dict: bool=True + ): + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + t_emb = t_emb.cast(hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) + hidden_states = self.embedding_dropout(hidden_states) + + text_seq_length = tuple(encoder_hidden_states.shape)[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + # 3. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + raise NotImplementedError + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb + ) + # print("hidden_states:", hidden_states.abs().mean().item(), hidden_states.min().item(), hidden_states.max().item()) + # print("encoder_hidden_states:", encoder_hidden_states.abs().mean().item(), encoder_hidden_states.min().item(), encoder_hidden_states.max().item()) + + if not self.config.use_rotary_positional_embeddings: + # 2B + hidden_states = self.norm_final(hidden_states) + else: + # 5B + hidden_states = paddle.concat(x=[encoder_hidden_states, + hidden_states], axis=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. 
Unpatchify + p = self.config.patch_size + output = hidden_states.reshape([batch_size, num_frames, height // p, width // p, -1, p, p]) + output = output.transpose(perm=[0, 1, 4, 2, 5, 3, 6]).flatten(5, 6).flatten(3, 4) + if not return_dict: + return output, + return Transformer2DModelOutput(sample=output) diff --git a/ppdiffusers/ppdiffusers/models/downsampling.py b/ppdiffusers/ppdiffusers/models/downsampling.py new file mode 100644 index 000000000..b99623ab1 --- /dev/null +++ b/ppdiffusers/ppdiffusers/models/downsampling.py @@ -0,0 +1,326 @@ +import paddle +from typing import Optional, Tuple +from .normalization import RMSNorm +from .upsampling import upfirdn2d_native + + +class Downsample1D(paddle.nn.Layer): + """A 1D downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. + name (`str`, default `conv`): + name of the downsampling 1D layer. + """ + + def __init__(self, channels: int, use_conv: bool=False, out_channels: + Optional[int]=None, padding: int=1, name: str='conv'): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + if use_conv: + self.conv = paddle.nn.Conv1D(in_channels=self.channels, + out_channels=self.out_channels, kernel_size=3, stride= + stride, padding=padding) + else: + assert self.channels == self.out_channels + self.conv = paddle.nn.AvgPool1D(kernel_size=stride, stride= + stride, exclusive=False) + + def forward(self, inputs: paddle.Tensor) ->paddle.Tensor: + assert tuple(inputs.shape)[1] == self.channels + return self.conv(inputs) + + +class Downsample2D(paddle.nn.Layer): + """A 2D downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. + name (`str`, default `conv`): + name of the downsampling 2D layer. 
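+
+    Example:
+        A minimal usage sketch (the import path follows this module's location; the channel count and spatial size
+        below are illustrative, not defaults):
+
+        ```python
+        import paddle
+
+        from ppdiffusers.models.downsampling import Downsample2D
+
+        downsample = Downsample2D(channels=64, use_conv=True)
+        x = paddle.randn([1, 64, 32, 32])
+        y = downsample(x)  # stride-2 conv halves the spatial size -> [1, 64, 16, 16]
+        ```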
+ """ + + def __init__(self, channels: int, use_conv: bool=False, out_channels: + Optional[int]=None, padding: int=1, name: str='conv', kernel_size=3, + norm_type=None, eps=None, elementwise_affine=None, bias=True): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + if norm_type == 'ln_norm': + self.norm = paddle.nn.LayerNorm(normalized_shape=channels, + epsilon=eps, weight_attr=elementwise_affine, bias_attr= + elementwise_affine) + elif norm_type == 'rms_norm': + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f'unknown norm_type: {norm_type}') + if use_conv: + conv = paddle.nn.Conv2D(in_channels=self.channels, out_channels + =self.out_channels, kernel_size=kernel_size, stride=stride, + padding=padding, bias_attr=bias) + else: + assert self.channels == self.out_channels + conv = paddle.nn.AvgPool2D(kernel_size=stride, stride=stride, + exclusive=False) + if name == 'conv': + self.Conv2d_0 = conv + self.conv = conv + elif name == 'Conv2d_0': + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states: paddle.Tensor, *args, **kwargs + ) ->paddle.Tensor: + if len(args) > 0 or kwargs.get('scale', None) is not None: + deprecation_message = ( + 'The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`.' + ) + print('scale', '1.0.0', deprecation_message) + assert tuple(hidden_states.shape)[1] == self.channels + if self.norm is not None: + hidden_states = self.norm(hidden_states.transpose(perm=[0, 2, 3, + 1])).transpose(perm=[0, 3, 1, 2]) + if self.use_conv and self.padding == 0: + pad = 0, 1, 0, 1 + hidden_states = paddle.nn.functional.pad(x=hidden_states, pad= + pad, mode='constant', value=0, pad_from_left_axis=False) + assert tuple(hidden_states.shape)[1] == self.channels + hidden_states = self.conv(hidden_states) + return hidden_states + + +class FirDownsample2D(paddle.nn.Layer): + """A 2D FIR downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + fir_kernel (`tuple`, default `(1, 3, 3, 1)`): + kernel for the FIR filter. + """ + + def __init__(self, channels: Optional[int]=None, out_channels: Optional + [int]=None, use_conv: bool=False, fir_kernel: Tuple[int, int, int, + int]=(1, 3, 3, 1)): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = paddle.nn.Conv2D(in_channels=channels, + out_channels=out_channels, kernel_size=3, stride=1, padding=1) + self.fir_kernel = fir_kernel + self.use_conv = use_conv + self.out_channels = out_channels + + def _downsample_2d(self, hidden_states: paddle.Tensor, weight: Optional + [paddle.Tensor]=None, kernel: Optional[paddle.Tensor]=None, factor: + int=2, gain: float=1) ->paddle.Tensor: + """Fused `Conv2d()` followed by `downsample_2d()`. + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. 
It supports gradients of + arbitrary order. + + Args: + hidden_states (`torch.Tensor`): + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight (`torch.Tensor`, *optional*): + Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be + performed by `inChannels = x.shape[0] // numGroups`. + kernel (`torch.Tensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to average pooling. + factor (`int`, *optional*, default to `2`): + Integer downsampling factor. + gain (`float`, *optional*, default to `1.0`): + Scaling factor for signal magnitude. + + Returns: + output (`torch.Tensor`): + Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same + datatype as `x`. + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + kernel = paddle.to_tensor(data=kernel, dtype='float32') + if kernel.ndim == 1: + kernel = paddle.outer(x=kernel, y=kernel) + kernel /= paddle.sum(x=kernel) + kernel = kernel * gain + if self.use_conv: + _, _, convH, convW = tuple(weight.shape) + pad_value = tuple(kernel.shape)[0] - factor + (convW - 1) + stride_value = [factor, factor] + upfirdn_input = upfirdn2d_native(hidden_states, paddle. + to_tensor(data=kernel, place=hidden_states.place), pad=(( + pad_value + 1) // 2, pad_value // 2)) + output = paddle.nn.functional.conv2d(x=upfirdn_input, weight= + weight, stride=stride_value, padding=0) + else: + pad_value = tuple(kernel.shape)[0] - factor + output = upfirdn2d_native(hidden_states, paddle.to_tensor(data= + kernel, place=hidden_states.place), down=factor, pad=(( + pad_value + 1) // 2, pad_value // 2)) + return output + + def forward(self, hidden_states: paddle.Tensor) ->paddle.Tensor: + if self.use_conv: + downsample_input = self._downsample_2d(hidden_states, weight= + self.Conv2d_0.weight, kernel=self.fir_kernel) + hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, + -1, 1, 1) + else: + hidden_states = self._downsample_2d(hidden_states, kernel=self. + fir_kernel, factor=2) + return hidden_states + + +class KDownsample2D(paddle.nn.Layer): + """A 2D K-downsampling layer. + + Parameters: + pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. + """ + + def __init__(self, pad_mode: str='reflect'): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = paddle.to_tensor(data=[[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) + self.pad = tuple(kernel_1d.shape)[1] // 2 - 1 + self.register_buffer(name='kernel', tensor=kernel_1d.T @ kernel_1d, + persistable=False) + + def forward(self, inputs: paddle.Tensor) ->paddle.Tensor: + inputs = paddle.nn.functional.pad(x=inputs, pad=(self.pad,) * 4, + mode=self.pad_mode, pad_from_left_axis=False) + weight = paddle.zeros(shape=[tuple(inputs.shape)[1], tuple(inputs. + shape)[1], tuple(self.kernel.shape)[0], tuple(self.kernel.shape + )[1]], dtype=inputs.dtype) + indices = paddle.arange(end=tuple(inputs.shape)[1]) + kernel = self.kernel.to(weight)[None, :].expand(shape=[tuple(inputs + .shape)[1], -1, -1]) + weight[indices, indices] = kernel + return paddle.nn.functional.conv2d(x=inputs, weight=weight, stride=2) + + +class CogVideoXDownsample3D(paddle.nn.Layer): + """ + A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI + + Args: + in_channels (`int`): + Number of channels in the input image. + out_channels (`int`): + Number of channels produced by the convolution. 
+ kernel_size (`int`, defaults to `3`): + Size of the convolving kernel. + stride (`int`, defaults to `2`): + Stride of the convolution. + padding (`int`, defaults to `0`): + Padding added to all four sides of the input. + compress_time (`bool`, defaults to `False`): + Whether or not to compress the time dimension. + """ + + def __init__(self, in_channels: int, out_channels: int, kernel_size: + int=3, stride: int=2, padding: int=0, compress_time: bool=False): + super().__init__() + self.conv = paddle.nn.Conv2D(in_channels=in_channels, out_channels= + out_channels, kernel_size=kernel_size, stride=stride, padding= + padding) + self.compress_time = compress_time + + def forward(self, x: paddle.Tensor) ->paddle.Tensor: + if self.compress_time: + batch_size, channels, frames, height, width = tuple(x.shape) + x = x.transpose(perm=[0, 3, 4, 1, 2]).reshape(batch_size * + height * width, channels, frames) + if tuple(x.shape)[-1] % 2 == 1: + x_first, x_rest = x[..., 0], x[..., 1:] + if tuple(x_rest.shape)[-1] > 0: + x_rest = paddle.nn.functional.avg_pool1d(kernel_size=2, + stride=2, x=x_rest, exclusive=False) + x = paddle.concat(x=[x_first[..., None], x_rest], axis=-1) + x = x.reshape(batch_size, height, width, channels, tuple(x. + shape)[-1]).transpose(perm=[0, 3, 4, 1, 2]) + else: + x = paddle.nn.functional.avg_pool1d(kernel_size=2, stride=2, + x=x, exclusive=False) + x = x.reshape(batch_size, height, width, channels, tuple(x. + shape)[-1]).transpose(perm=[0, 3, 4, 1, 2]) + pad = 0, 1, 0, 1 + x = paddle.nn.functional.pad(x=x, pad=pad, mode='constant', value=0, + pad_from_left_axis=False) + batch_size, channels, frames, height, width = tuple(x.shape) + x = x.transpose(perm=[0, 2, 1, 3, 4]).reshape(batch_size * frames, + channels, height, width) + x = self.conv(x) + x = x.reshape(batch_size, frames, tuple(x.shape)[1], tuple(x.shape) + [2], tuple(x.shape)[3]).transpose(perm=[0, 2, 1, 3, 4]) + return x + + +def downsample_2d(hidden_states: paddle.Tensor, kernel: Optional[paddle. + Tensor]=None, factor: int=2, gain: float=1) ->paddle.Tensor: + """Downsample2D a batch of 2D images with the given filter. + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the + given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the + specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its + shape is a multiple of the downsampling factor. + + Args: + hidden_states (`torch.Tensor`) + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel (`torch.Tensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to average pooling. + factor (`int`, *optional*, default to `2`): + Integer downsampling factor. + gain (`float`, *optional*, default to `1.0`): + Scaling factor for signal magnitude. 
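+
+    Example:
+        A minimal sketch (the input shape is illustrative); with the default kernel `[1] * factor` this amounts to
+        2x2 average pooling:
+
+        ```python
+        import paddle
+
+        x = paddle.randn([1, 3, 64, 64])
+        y = downsample_2d(x, factor=2)  # -> [1, 3, 32, 32]
+        ```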
+ + Returns: + output (`torch.Tensor`): + Tensor of the shape `[N, C, H // factor, W // factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + kernel = paddle.to_tensor(data=kernel, dtype='float32') + if kernel.ndim == 1: + kernel = paddle.outer(x=kernel, y=kernel) + kernel /= paddle.sum(x=kernel) + kernel = kernel * gain + pad_value = tuple(kernel.shape)[0] - factor + output = upfirdn2d_native(hidden_states, kernel.to(device=hidden_states + .place), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)) + return output diff --git a/ppdiffusers/ppdiffusers/models/embeddings.py b/ppdiffusers/ppdiffusers/models/embeddings.py index d339a3922..f80be4f16 100644 --- a/ppdiffusers/ppdiffusers/models/embeddings.py +++ b/ppdiffusers/ppdiffusers/models/embeddings.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Optional +from typing import List, Optional, Tuple, Union import numpy as np import paddle @@ -913,3 +913,225 @@ def forward(self, caption): hidden_states = self.act_1(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states + + +def get_3d_sincos_pos_embed(embed_dim: int, spatial_size: Union[int, Tuple[ + int, int]], temporal_size: int, spatial_interpolation_scale: float=1.0, + temporal_interpolation_scale: float=1.0) ->np.ndarray: + """ + Args: + embed_dim (`int`): + spatial_size (`int` or `Tuple[int, int]`): + temporal_size (`int`): + spatial_interpolation_scale (`float`, defaults to 1.0): + temporal_interpolation_scale (`float`, defaults to 1.0): + """ + if embed_dim % 4 != 0: + raise ValueError('`embed_dim` must be divisible by 4') + if isinstance(spatial_size, int): + spatial_size = spatial_size, spatial_size + embed_dim_spatial = 3 * embed_dim // 4 + embed_dim_temporal = embed_dim // 4 + grid_h = np.arange(spatial_size[1], dtype=np.float32 + ) / spatial_interpolation_scale + grid_w = np.arange(spatial_size[0], dtype=np.float32 + ) / spatial_interpolation_scale + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0) + grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]]) + pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, + grid) + grid_t = np.arange(temporal_size, dtype=np.float32 + ) / temporal_interpolation_scale + pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, + grid_t) + pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :] + pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) + pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :] + pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * + spatial_size[1], axis=1) + pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1 + ) + return pos_embed + +class CogVideoXPatchEmbed(paddle.nn.Layer): + + def __init__( + self, + patch_size: int=2, + in_channels: int=16, + embed_dim: int=1920, + text_embed_dim: int=4096, + bias: bool=True, + sample_width: int=90, + sample_height: int=60, + sample_frames: int=49, + temporal_compression_ratio: int=4, + max_text_seq_length: int=226, + spatial_interpolation_scale: float=1.875, + temporal_interpolation_scale: float=1.0, + use_positional_embeddings: bool=True, + use_learned_positional_embeddings: bool=True + ) ->None: + super().__init__() + self.patch_size = patch_size + self.embed_dim = embed_dim + self.sample_height = sample_height + self.sample_width = sample_width + 
self.sample_frames = sample_frames + self.temporal_compression_ratio = temporal_compression_ratio + self.max_text_seq_length = max_text_seq_length + self.spatial_interpolation_scale = spatial_interpolation_scale + self.temporal_interpolation_scale = temporal_interpolation_scale + self.use_positional_embeddings = use_positional_embeddings + self.use_learned_positional_embeddings = use_learned_positional_embeddings + + self.proj = paddle.nn.Conv2D(in_channels=in_channels, out_channels= + embed_dim, kernel_size=(patch_size, patch_size), stride= + patch_size, bias_attr=bias) + + self.text_proj = paddle.nn.Linear(in_features=text_embed_dim, + out_features=embed_dim) + + if use_positional_embeddings or use_learned_positional_embeddings: + persistent = use_learned_positional_embeddings + pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames) + self.register_buffer(name='pos_embedding', tensor=pos_embedding, + persistable=persistent) + + def _get_positional_embeddings(self, sample_height: int, sample_width: + int, sample_frames: int) ->paddle.Tensor: + post_patch_height = sample_height // self.patch_size + post_patch_width = sample_width // self.patch_size + post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 + num_patches = (post_patch_height * post_patch_width *post_time_compression_frames) + + pos_embedding = get_3d_sincos_pos_embed( + self.embed_dim, ( + post_patch_width, post_patch_height), + post_time_compression_frames, self.spatial_interpolation_scale, + self.temporal_interpolation_scale + ) + pos_embedding = paddle.to_tensor(data=pos_embedding).flatten(start_axis=0, stop_axis=1) + joint_pos_embedding = paddle.zeros([1, self.max_text_seq_length + num_patches, self.embed_dim]) + joint_pos_embedding[0, self.max_text_seq_length :] = pos_embedding + return joint_pos_embedding + + def forward(self, text_embeds: paddle.Tensor, image_embeds: paddle.Tensor): + """ + Args: + text_embeds (`torch.Tensor`): + Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim). + image_embeds (`torch.Tensor`): + Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width). + """ + text_embeds = self.text_proj(text_embeds) + # import numpy as np + # text_embeds = paddle.to_tensor(np.load("../CogVideo/inference/text_embeds.npy"), dtype=paddle.float32) + + batch, num_frames, channels, height, width = image_embeds.shape + image_embeds = image_embeds.reshape([-1, channels, height, width]) + image_embeds = self.proj(image_embeds) + image_embeds = image_embeds.reshape([batch, num_frames] + image_embeds.shape[1:]) + image_embeds = image_embeds.flatten(3).transpose([0, 1, 3, 2]) # [batch, num_frames, height x width, channels] + image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] + + embeds = paddle.concat(x=[text_embeds, image_embeds], axis=1 + ).contiguous() + if (self.use_positional_embeddings or self.use_learned_positional_embeddings): + if self.use_learned_positional_embeddings and (self. + sample_width != width or self.sample_height != height): + raise ValueError( + "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'.If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues." 
+ ) + pre_time_compression_frames = (num_frames - 1 + ) * self.temporal_compression_ratio + 1 + if (self.sample_height != height or self.sample_width != width or + self.sample_frames != pre_time_compression_frames): + pos_embedding = self._get_positional_embeddings(height, + width, pre_time_compression_frames) + pos_embedding = pos_embedding.cast(embeds.dtype) + else: + pos_embedding = self.pos_embedding + embeds = embeds + pos_embedding + return embeds + +def get_3d_rotary_pos_embed(embed_dim, crops_coords, grid_size, + temporal_size, theta: int=10000, use_real: bool=True) ->Union[paddle. + Tensor, Tuple[paddle.Tensor, paddle.Tensor]]: + """ + RoPE for video tokens with 3D structure. + + Args: + embed_dim: (`int`): + The embedding dimension size, corresponding to hidden_size_head. + crops_coords (`Tuple[int]`): + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the spatial positional embedding (height, width). + temporal_size (`int`): + The size of the temporal dimension. + theta (`float`): + Scaling factor for frequency computation. + use_real (`bool`): + If True, return real part and imaginary part separately. Otherwise, return complex numbers. + + Returns: + `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. + """ + start, stop = crops_coords + grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, + dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, + dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, + dtype=np.float32) + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + freqs_t = 1.0 / theta ** (paddle.arange(start=0, end=dim_t, step=2). + astype(dtype='float32') / dim_t) + grid_t = paddle.to_tensor(data=grid_t).astype(dtype='float32') + freqs_t = paddle.einsum('n , f -> n f', grid_t, freqs_t) + freqs_t = freqs_t.repeat_interleave(repeats=2, axis=-1) + freqs_h = 1.0 / theta ** (paddle.arange(start=0, end=dim_h, step=2). + astype(dtype='float32') / dim_h) + freqs_w = 1.0 / theta ** (paddle.arange(start=0, end=dim_w, step=2). 
+ astype(dtype='float32') / dim_w) + grid_h = paddle.to_tensor(data=grid_h).astype(dtype='float32') + grid_w = paddle.to_tensor(data=grid_w).astype(dtype='float32') + freqs_h = paddle.einsum('n , f -> n f', grid_h, freqs_h) + freqs_w = paddle.einsum('n , f -> n f', grid_w, freqs_w) + freqs_h = freqs_h.repeat_interleave(repeats=2, axis=-1) + freqs_w = freqs_w.repeat_interleave(repeats=2, axis=-1) + + def broadcast(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = {len(tuple(t.shape)) for t in tensors} + assert len(shape_lens + ) == 1, 'tensors must all have the same number of dimensions' + shape_len = list(shape_lens)[0] + dim = dim + shape_len if dim < 0 else dim + dims = list(zip(*(list(tuple(t.shape)) for t in tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all([*(len(set(t[1])) <= 2 for t in expandable_dims)] + ), 'invalid dimensions for broadcastable concatenation' + max_dims = [(t[0], max(t[1])) for t in expandable_dims] + expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims] + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*(t[1] for t in expanded_dims))) + tensors = [t[0].expand(shape=t[1]) for t in zip(tensors, + expandable_shapes)] + return paddle.concat(x=tensors, axis=dim) + freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], + freqs_w[None, None, :, :]), dim=-1) + t, h, w, d = tuple(freqs.shape) + freqs = freqs.view(t * h * w, d) + sin = freqs.sin() + cos = freqs.cos() + if use_real: + return cos, sin + else: + freqs_cis = paddle.complex(paddle.ones_like(x=freqs) * paddle.cos( + freqs), paddle.ones_like(x=freqs) * paddle.sin(freqs)) + return freqs_cis \ No newline at end of file diff --git a/ppdiffusers/ppdiffusers/models/normalization.py b/ppdiffusers/ppdiffusers/models/normalization.py index 68a732cc0..e582905f6 100644 --- a/ppdiffusers/ppdiffusers/models/normalization.py +++ b/ppdiffusers/ppdiffusers/models/normalization.py @@ -32,17 +32,45 @@ class AdaLayerNorm(nn.Layer): num_embeddings (`int`): The size of the embeddings dictionary. 
""" - def __init__(self, embedding_dim: int, num_embeddings: int): + def __init__( + self, + embedding_dim: int, + num_embeddings: Optional[int] = None, + output_dim: Optional[int] = None, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-5, + chunk_dim: int = 0, + ): super().__init__() - self.emb = nn.Embedding(num_embeddings, embedding_dim) + + self.chunk_dim = chunk_dim + output_dim = output_dim or embedding_dim * 2 + + if num_embeddings is not None: + self.emb = nn.Embedding(num_embeddings, embedding_dim) + else: + self.emb = None + self.silu = nn.Silu() - self.linear = nn.Linear(embedding_dim, embedding_dim * 2) - norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False) - self.norm = nn.LayerNorm(embedding_dim, **norm_elementwise_affine_kwargs) + self.linear = nn.Linear(embedding_dim, output_dim) + if norm_elementwise_affine: + norm_elementwise_affine_kwargs = dict(weight_attr=None, bias_attr=None) + else: + norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False) + self.norm = nn.LayerNorm(output_dim // 2, epsilon=norm_eps, **norm_elementwise_affine_kwargs) - def forward(self, x: paddle.Tensor, timestep: paddle.Tensor) -> paddle.Tensor: - emb = self.linear(self.silu(self.emb(timestep))) - scale, shift = paddle.chunk(emb, 2) + def forward(self, x: paddle.Tensor, timestep: Optional[paddle.Tensor] = None, temb: Optional[paddle.Tensor] = None) -> paddle.Tensor: + if self.emb is not None: + temb = self.emb(timestep) + temb = self.linear(self.silu(temb)) + if self.chunk_dim == 1: + # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the + # other if-branch. This branch is specific to CogVideoX for now. + shift, scale = paddle.chunk(temb, 2, axis=1) + shift = shift[:, None, :] + scale = scale[:, None, :] + else: + scale, shift = paddle.chunk(temb, 2) x = self.norm(x) * (1 + scale) + shift return x @@ -224,3 +252,29 @@ def forward(self, hidden_states): epsilon=self.epsilon, begin_norm_axis=2, ) + + +class CogVideoXLayerNormZero(paddle.nn.Layer): + + def __init__(self, conditioning_dim: int, embedding_dim: int, + elementwise_affine: bool=True, eps: float=1e-05, bias: bool=True + ) ->None: + super().__init__() + self.silu = paddle.nn.Silu() + self.linear = paddle.nn.Linear(in_features=conditioning_dim, + out_features=6 * embedding_dim, bias_attr=bias) + self.norm = paddle.nn.LayerNorm(normalized_shape=embedding_dim, + epsilon=eps, weight_attr=elementwise_affine, bias_attr= + elementwise_affine) + + def forward(self, hidden_states: paddle.Tensor, encoder_hidden_states: + paddle.Tensor, temb: paddle.Tensor) ->Tuple[paddle.Tensor, paddle. + Tensor]: + shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self + .silu(temb)).chunk(chunks=6, axis=1) + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, : + ] + shift[:, None, :] + encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + + enc_scale)[:, None, :] + enc_shift[:, None, :] + return hidden_states, encoder_hidden_states, gate[:, None, : + ], enc_gate[:, None, :] diff --git a/ppdiffusers/ppdiffusers/models/upsampling.py b/ppdiffusers/ppdiffusers/models/upsampling.py new file mode 100644 index 000000000..f9ac22d6b --- /dev/null +++ b/ppdiffusers/ppdiffusers/models/upsampling.py @@ -0,0 +1,401 @@ +import paddle +from typing import Optional, Tuple +from .normalization import RMSNorm + + +class Upsample1D(paddle.nn.Layer): + """A 1D upsampling layer with an optional convolution. 
+ + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + name (`str`, default `conv`): + name of the upsampling 1D layer. + """ + + def __init__(self, channels: int, use_conv: bool=False, + use_conv_transpose: bool=False, out_channels: Optional[int]=None, + name: str='conv'): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + self.conv = None + if use_conv_transpose: + self.conv = paddle.nn.Conv1DTranspose(in_channels=channels, + out_channels=self.out_channels, kernel_size=4, stride=2, + padding=1) + elif use_conv: + self.conv = paddle.nn.Conv1D(in_channels=self.channels, + out_channels=self.out_channels, kernel_size=3, padding=1) + + def forward(self, inputs: paddle.Tensor) ->paddle.Tensor: + assert tuple(inputs.shape)[1] == self.channels + if self.use_conv_transpose: + return self.conv(inputs) + outputs = paddle.nn.functional.interpolate(x=inputs, scale_factor= + 2.0, mode='nearest') + if self.use_conv: + outputs = self.conv(outputs) + return outputs + + +class Upsample2D(paddle.nn.Layer): + """A 2D upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + name (`str`, default `conv`): + name of the upsampling 2D layer. + """ + + def __init__(self, channels: int, use_conv: bool=False, + use_conv_transpose: bool=False, out_channels: Optional[int]=None, + name: str='conv', kernel_size: Optional[int]=None, padding=1, + norm_type=None, eps=None, elementwise_affine=None, bias=True, + interpolate=True): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + self.interpolate = interpolate + if norm_type == 'ln_norm': + self.norm = paddle.nn.LayerNorm(normalized_shape=channels, + epsilon=eps, weight_attr=elementwise_affine, bias_attr= + elementwise_affine) + elif norm_type == 'rms_norm': + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f'unknown norm_type: {norm_type}') + conv = None + if use_conv_transpose: + if kernel_size is None: + kernel_size = 4 + conv = paddle.nn.Conv2DTranspose(in_channels=channels, + out_channels=self.out_channels, kernel_size=kernel_size, + stride=2, padding=padding, bias_attr=bias) + elif use_conv: + if kernel_size is None: + kernel_size = 3 + conv = paddle.nn.Conv2D(in_channels=self.channels, out_channels + =self.out_channels, kernel_size=kernel_size, padding= + padding, bias_attr=bias) + if name == 'conv': + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states: paddle.Tensor, output_size: Optional[ + int]=None, *args, **kwargs) ->paddle.Tensor: + if len(args) > 0 or kwargs.get('scale', None) is not None: + deprecation_message = ( + 'The `scale` argument is deprecated and will be ignored. 
Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`.' + ) + print('scale', '1.0.0', deprecation_message) + assert tuple(hidden_states.shape)[1] == self.channels + if self.norm is not None: + hidden_states = self.norm(hidden_states.transpose(perm=[0, 2, 3, + 1])).transpose(perm=[0, 3, 1, 2]) + if self.use_conv_transpose: + return self.conv(hidden_states) + dtype = hidden_states.dtype + if dtype == 'bfloat16': + hidden_states = hidden_states.cast('float32') + if tuple(hidden_states.shape)[0] >= 64: + hidden_states = hidden_states.contiguous() + if self.interpolate: + if output_size is None: + hidden_states = paddle.nn.functional.interpolate(x= + hidden_states, scale_factor=2.0, mode='nearest') + else: + hidden_states = paddle.nn.functional.interpolate(x= + hidden_states, size=output_size, mode='nearest') + if dtype == 'bfloat16': + hidden_states = hidden_states.cast(dtype) + if self.use_conv: + if self.name == 'conv': + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + return hidden_states + + +class FirUpsample2D(paddle.nn.Layer): + """A 2D FIR upsampling layer with an optional convolution. + + Parameters: + channels (`int`, optional): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + fir_kernel (`tuple`, default `(1, 3, 3, 1)`): + kernel for the FIR filter. + """ + + def __init__(self, channels: Optional[int]=None, out_channels: Optional + [int]=None, use_conv: bool=False, fir_kernel: Tuple[int, int, int, + int]=(1, 3, 3, 1)): + super().__init__() + out_channels = out_channels if out_channels else channels + if use_conv: + self.Conv2d_0 = paddle.nn.Conv2D(in_channels=channels, + out_channels=out_channels, kernel_size=3, stride=1, padding=1) + self.use_conv = use_conv + self.fir_kernel = fir_kernel + self.out_channels = out_channels + + def _upsample_2d(self, hidden_states: paddle.Tensor, weight: Optional[ + paddle.Tensor]=None, kernel: Optional[paddle.Tensor]=None, factor: + int=2, gain: float=1) ->paddle.Tensor: + """Fused `upsample_2d()` followed by `Conv2d()`. + + Padding is performed only once at the beginning, not between the operations. The fused op is considerably more + efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of + arbitrary order. + + Args: + hidden_states (`torch.Tensor`): + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + weight (`torch.Tensor`, *optional*): + Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be + performed by `inChannels = x.shape[0] // numGroups`. + kernel (`torch.Tensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to nearest-neighbor upsampling. + factor (`int`, *optional*): Integer upsampling factor (default: 2). + gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0). + + Returns: + output (`torch.Tensor`): + Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same + datatype as `hidden_states`. 
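As a small aside on the kernel set-up described above (not the fused transposed-conv path itself): with the default separable filter `[1] * factor`, normalisation followed by the `gain * factor**2` rescale leaves a filter of all ones, which is plain nearest-neighbour replication. A quick check:

```python
import paddle

factor, gain = 2, 1.0
kernel = paddle.to_tensor([1.0] * factor)   # default separable FIR filter
kernel = paddle.outer(kernel, kernel)       # 2D filter, shape [factor, factor]
kernel = kernel / paddle.sum(kernel)        # normalise to unit DC gain
kernel = kernel * (gain * factor ** 2)      # compensate for zero-insertion upsampling
print(kernel.numpy())                       # [[1. 1.] [1. 1.]] -> pixel replication
```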
+ """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + kernel = paddle.to_tensor(data=kernel, dtype='float32') + if kernel.ndim == 1: + kernel = paddle.outer(x=kernel, y=kernel) + kernel /= paddle.sum(x=kernel) + kernel = kernel * (gain * factor ** 2) + if self.use_conv: + convH = tuple(weight.shape)[2] + convW = tuple(weight.shape)[3] + inC = tuple(weight.shape)[1] + pad_value = tuple(kernel.shape)[0] - factor - (convW - 1) + stride = factor, factor + output_shape = (tuple(hidden_states.shape)[2] - 1 + ) * factor + convH, (tuple(hidden_states.shape)[3] - 1 + ) * factor + convW + output_padding = output_shape[0] - (tuple(hidden_states.shape)[ + 2] - 1) * stride[0] - convH, output_shape[1] - (tuple( + hidden_states.shape)[3] - 1) * stride[1] - convW + assert output_padding[0] >= 0 and output_padding[1] >= 0 + num_groups = tuple(hidden_states.shape)[1] // inC + weight = paddle.reshape(x=weight, shape=(num_groups, -1, inC, + convH, convW)) + weight = paddle.flip(x=weight, axis=[3, 4]).transpose(perm=[0, + 2, 1, 3, 4]) + weight = paddle.reshape(x=weight, shape=(num_groups * inC, -1, + convH, convW)) + inverse_conv = paddle.nn.functional.conv2d_transpose(x= + hidden_states, weight=weight, stride=stride, output_padding + =output_padding, padding=0) + output = upfirdn2d_native(inverse_conv, paddle.to_tensor(data= + kernel), pad=((pad_value + 1) // + 2 + factor - 1, pad_value // 2 + 1)) + else: + pad_value = tuple(kernel.shape)[0] - factor + output = upfirdn2d_native(hidden_states, paddle.to_tensor(data= + kernel), up=factor, pad=(( + pad_value + 1) // 2 + factor - 1, pad_value // 2)) + return output + + def forward(self, hidden_states: paddle.Tensor) ->paddle.Tensor: + if self.use_conv: + height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, + kernel=self.fir_kernel) + height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) + else: + height = self._upsample_2d(hidden_states, kernel=self. + fir_kernel, factor=2) + return height + + +class KUpsample2D(paddle.nn.Layer): + """A 2D K-upsampling layer. + + Parameters: + pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use. + """ + + def __init__(self, pad_mode: str='reflect'): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = paddle.to_tensor(data=[[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2 + self.pad = tuple(kernel_1d.shape)[1] // 2 - 1 + self.register_buffer(name='kernel', tensor=kernel_1d.T @ kernel_1d, + persistable=False) + + def forward(self, inputs: paddle.Tensor) ->paddle.Tensor: + inputs = paddle.nn.functional.pad(x=inputs, pad=((self.pad + 1) // + 2,) * 4, mode=self.pad_mode, pad_from_left_axis=False) + weight = paddle.zeros(shape=[tuple(inputs.shape)[1], tuple(inputs. + shape)[1], tuple(self.kernel.shape)[0], tuple(self.kernel.shape + )[1]], dtype=inputs.dtype) + indices = paddle.arange(end=tuple(inputs.shape)[1]) + kernel = self.kernel.cast(weight.dtype)[None, :].expand(shape=[tuple(inputs + .shape)[1], -1, -1]) + weight[indices, indices] = kernel + return paddle.nn.functional.conv2d_transpose(x=inputs, weight= + weight, stride=2, padding=self.pad * 2 + 1) + + +class CogVideoXUpsample3D(paddle.nn.Layer): + """ + A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase. + + Args: + in_channels (`int`): + Number of channels in the input image. + out_channels (`int`): + Number of channels produced by the convolution. + kernel_size (`int`, defaults to `3`): + Size of the convolving kernel. 
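For orientation before the 3D CogVideoX variant below, a hedged usage sketch of the plain `Upsample2D` layer defined earlier in this file (the import path is the one this patch creates; the channel count and input size are arbitrary):

```python
import paddle
from ppdiffusers.models.upsampling import Upsample2D  # module added by this patch

up = Upsample2D(channels=64, use_conv=True)
x = paddle.randn([1, 64, 32, 32])
y = up(x)          # nearest-neighbour x2 interpolation, then a 3x3 convolution
print(y.shape)     # [1, 64, 64, 64]
```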
+ stride (`int`, defaults to `1`): + Stride of the convolution. + padding (`int`, defaults to `1`): + Padding added to all four sides of the input. + compress_time (`bool`, defaults to `False`): + Whether or not to compress the time dimension. + """ + + def __init__(self, in_channels: int, out_channels: int, kernel_size: + int=3, stride: int=1, padding: int=1, compress_time: bool=False + ) ->None: + super().__init__() + self.conv = paddle.nn.Conv2D(in_channels=in_channels, out_channels= + out_channels, kernel_size=kernel_size, stride=stride, padding= + padding) + self.compress_time = compress_time + + def forward(self, inputs: paddle.Tensor) ->paddle.Tensor: + if self.compress_time: + if tuple(inputs.shape)[2] > 1 and tuple(inputs.shape)[2] % 2 == 1: + x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:] + x_first = paddle.nn.functional.interpolate(x=x_first, + scale_factor=2.0) + x_rest = paddle.nn.functional.interpolate(x=x_rest, + scale_factor=2.0) + x_first = x_first[:, :, None, :, :] + inputs = paddle.concat(x=[x_first, x_rest], axis=2) + elif tuple(inputs.shape)[2] > 1: + inputs = paddle.nn.functional.interpolate(x=inputs, + scale_factor=2.0) + else: + inputs = inputs.squeeze(axis=2) + inputs = paddle.nn.functional.interpolate(x=inputs, + scale_factor=2.0) + inputs = inputs[:, :, None, :, :] + else: + b, c, t, h, w = tuple(inputs.shape) + inputs = inputs.transpose(perm=[0, 2, 1, 3, 4]).reshape([b * t, + c, h, w]) + inputs = paddle.nn.functional.interpolate(x=inputs, + scale_factor=2.0) + inputs = inputs.reshape([b, t, c] + inputs.shape[2:]).transpose(perm=[0, 2, 1, 3, 4]) + b, c, t, h, w = inputs.shape + inputs = inputs.transpose(perm=[0, 2, 1, 3, 4]).reshape([b * t, c, h, w]) + inputs = self.conv(inputs) + inputs = inputs.reshape([b, t] + inputs.shape[1:]).transpose(perm + =[0, 2, 1, 3, 4]) + return inputs + + +def upfirdn2d_native(tensor: paddle.Tensor, kernel: paddle.Tensor, up: int= + 1, down: int=1, pad: Tuple[int, int]=(0, 0)) ->paddle.Tensor: + up_x = up_y = up + down_x = down_y = down + pad_x0 = pad_y0 = pad[0] + pad_x1 = pad_y1 = pad[1] + _, channel, in_h, in_w = tuple(tensor.shape) + tensor = tensor.reshape(-1, in_h, in_w, 1) + _, in_h, in_w, minor = tuple(tensor.shape) + kernel_h, kernel_w = tuple(kernel.shape) + out = tensor.view(-1, in_h, 1, in_w, 1, minor) + out = paddle.nn.functional.pad(x=out, pad=[0, 0, 0, up_x - 1, 0, 0, 0, + up_y - 1], pad_from_left_axis=False) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + out = paddle.nn.functional.pad(x=out, pad=[0, 0, max(pad_x0, 0), max( + pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)], pad_from_left_axis=False) + out = out[:, max(-pad_y0, 0):tuple(out.shape)[1] - max(-pad_y1, 0), max + (-pad_x0, 0):tuple(out.shape)[2] - max(-pad_x1, 0), :] + out = out.transpose(perm=[0, 3, 1, 2]) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + + pad_x0 + pad_x1]) + w = paddle.flip(x=kernel, axis=[0, 1]).view(1, 1, kernel_h, kernel_w) + out = paddle.nn.functional.conv2d(x=out, weight=w) + out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1) + out = out.transpose(perm=[0, 2, 3, 1]) + out = out[:, ::down_y, ::down_x, :] + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + return out.view(-1, channel, out_h, out_w) + + +def upsample_2d(hidden_states: paddle.Tensor, kernel: Optional[paddle. 
+ Tensor]=None, factor: int=2, gain: float=1) ->paddle.Tensor: + """Upsample2D a batch of 2D images with the given filter. + Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given + filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified + `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is + a: multiple of the upsampling factor. + + Args: + hidden_states (`torch.Tensor`): + Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. + kernel (`torch.Tensor`, *optional*): + FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which + corresponds to nearest-neighbor upsampling. + factor (`int`, *optional*, default to `2`): + Integer upsampling factor. + gain (`float`, *optional*, default to `1.0`): + Scaling factor for signal magnitude (default: 1.0). + + Returns: + output (`torch.Tensor`): + Tensor of the shape `[N, C, H * factor, W * factor]` + """ + assert isinstance(factor, int) and factor >= 1 + if kernel is None: + kernel = [1] * factor + kernel = paddle.to_tensor(data=kernel, dtype='float32') + if kernel.ndim == 1: + kernel = paddle.outer(x=kernel, y=kernel) + kernel /= paddle.sum(x=kernel) + kernel = kernel * (gain * factor ** 2) + pad_value = tuple(kernel.shape)[0] - factor + output = upfirdn2d_native(hidden_states, kernel, up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2)) + return output diff --git a/ppdiffusers/ppdiffusers/pipelines/__init__.py b/ppdiffusers/ppdiffusers/pipelines/__init__.py index b909c884d..d0a858e6c 100644 --- a/ppdiffusers/ppdiffusers/pipelines/__init__.py +++ b/ppdiffusers/ppdiffusers/pipelines/__init__.py @@ -101,6 +101,7 @@ "AudioLDM2UNet2DConditionModel", ] _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"] + _import_structure["cogvideo"] = ["CogVideoXPipeline"] _import_structure["controlnet"].extend( [ "BlipDiffusionControlNetPipeline", @@ -375,6 +376,7 @@ AudioLDM2UNet2DConditionModel, ) from .blip_diffusion import BlipDiffusionPipeline + from .cogvideo import CogVideoXPipeline from .controlnet import ( BlipDiffusionControlNetPipeline, PaddleInferStableDiffusionControlNetPipeline, diff --git a/ppdiffusers/ppdiffusers/pipelines/cogvideo/__init__.py b/ppdiffusers/ppdiffusers/pipelines/cogvideo/__init__.py new file mode 100644 index 000000000..8f333d821 --- /dev/null +++ b/ppdiffusers/ppdiffusers/pipelines/cogvideo/__init__.py @@ -0,0 +1,54 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + PPDIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_paddle_available, + is_paddlenlp_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_paddlenlp_available() and is_paddle_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_paddle_and_paddlenlp_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_paddle_and_paddlenlp_objects)) +else: + _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"] + # _import_structure["pipeline_cogvideox_fun_control"] = ["CogVideoXFunControlPipeline"] + # _import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"] + # _import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"] + +if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT: + try: + if not 
(is_paddlenlp_available() and is_paddle_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_paddle_and_paddlenlp_objects import * + else: + from .pipeline_cogvideox import CogVideoXPipeline + # from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline + # from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline + # from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) \ No newline at end of file diff --git a/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox.py b/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox.py new file mode 100644 index 000000000..2e472c71f --- /dev/null +++ b/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -0,0 +1,718 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from typing import Callable, Dict, List, Optional, Tuple, Union + +import paddle +# from paddlenlp.transformers import T5EncoderModel, T5Tokenizer +from ppdiffusers.transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel +from ...models.embeddings import get_3d_rotary_pos_embed +from ..pipeline_utils import DiffusionPipeline +from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from ...utils import logging, replace_example_docstring +from ...utils.paddle_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import CogVideoXPipelineOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import CogVideoXPipeline + >>> from diffusers.utils import export_to_video + + >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b" + >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... 
"atmosphere of this unique musical performance." + ... ) + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0] + >>> export_to_video(video, "output.mp4", fps=8) + ``` +""" + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class CogVideoXPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + dtype: Optional[paddle.dtype] = None, + ): + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pd", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pd").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids)[0] + prompt_embeds = prompt_embeds.cast(dtype) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.tile([1, num_videos_per_prompt, 1]) + prompt_embeds = prompt_embeds.reshape([batch_size * num_videos_per_prompt, seq_len, -1]) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = 
True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[paddle.Tensor] = None, + negative_prompt_embeds: Optional[paddle.Tensor] = None, + max_sequence_length: int = 226, + dtype: Optional[paddle.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, generator, latents=None + ): + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, dtype=dtype) + else: + latents = latents + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: paddle.Tensor) -> paddle.Tensor: + latents = latents.transpose([0, 2, 1, 3, 4]) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + use_real=True, + ) + + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @paddle.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None, + latents: Optional[paddle.Tensor] = None, + prompt_embeds: Optional[paddle.Tensor] = None, + negative_prompt_embeds: Optional[paddle.Tensor] = None, + output_type: str = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + ) -> Union[CogVideoXPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). 
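For a sense of scale, here is a back-of-the-envelope calculation (pure Python) of the sizes behind the defaults above: 49 requested frames become 13 latent frames, and the default 480x720 resolution gives a 30x45 grid of patch positions per frame. The VAE factors of 8 (spatial) and 4 (temporal) and `patch_size=2` come from the code in this file; treat them as assumptions for any other checkpoint.

```python
height, width, num_frames = 480, 720, 49
vae_scale_factor_spatial, vae_scale_factor_temporal, patch_size = 8, 4, 2

latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1   # 13
grid_height = height // (vae_scale_factor_spatial * patch_size)     # 30
grid_width = width // (vae_scale_factor_spatial * patch_size)       # 45

print(latent_frames, grid_height, grid_width)     # 13 30 45
print(latent_frames * grid_height * grid_width)   # 17550 rotary positions
```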
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. 
The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if num_frames > 49: + raise ValueError( + "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation." + ) + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = paddle.concat([negative_prompt_embeds, prompt_embeds], axis=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. 
Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1)) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + import numpy as np + # latents = paddle.to_tensor(np.load("../CogVideo/inference/latent.npy"), dtype=paddle.bfloat16) + # prompt_embeds = paddle.to_tensor(np.load("../CogVideo/inference/prompt.npy"), dtype=paddle.float32) + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + noise_pred = noise_pred.cast('float32') + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.cast(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return CogVideoXPipelineOutput(frames=video) \ No newline at end of file diff --git a/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_output.py b/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_output.py new file mode 100644 index 000000000..4bf9efaa3 --- /dev/null +++ b/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import paddle + +from ...utils import BaseOutput + + +@dataclass +class 
CogVideoXPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + frames (`paddle.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: paddle.Tensor \ No newline at end of file diff --git a/ppdiffusers/ppdiffusers/schedulers/__init__.py b/ppdiffusers/ppdiffusers/schedulers/__init__.py index ae1ec2232..e914488e8 100644 --- a/ppdiffusers/ppdiffusers/schedulers/__init__.py +++ b/ppdiffusers/ppdiffusers/schedulers/__init__.py @@ -41,12 +41,14 @@ _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] _import_structure["scheduling_ddim"] = ["DDIMScheduler"] + _import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"] _import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"] _import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"] _import_structure["scheduling_ddpm"] = ["DDPMScheduler"] _import_structure["scheduling_ddpm_parallel"] = ["DDPMParallelScheduler"] _import_structure["scheduling_ddpm_wuerstchen"] = ["DDPMWuerstchenScheduler"] _import_structure["scheduling_deis_multistep"] = ["DEISMultistepScheduler"] + _import_structure["scheduling_dpm_cogvideox"] = ["CogVideoXDPMScheduler"] _import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"] _import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"] _import_structure["scheduling_dpmsolver_singlestep"] = ["DPMSolverSinglestepScheduler"] @@ -118,12 +120,14 @@ from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler + from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler from .scheduling_ddim_inverse import DDIMInverseScheduler from .scheduling_ddim_parallel import DDIMParallelScheduler from .scheduling_ddpm import DDPMScheduler from .scheduling_ddpm_parallel import DDPMParallelScheduler from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler from .scheduling_deis_multistep import DEISMultistepScheduler + from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler from .scheduling_dpmsolver_multistep_inverse import ( DPMSolverMultistepInverseScheduler, diff --git a/ppdiffusers/ppdiffusers/schedulers/scheduling_ddim_cogvideox.py b/ppdiffusers/ppdiffusers/schedulers/scheduling_ddim_cogvideox.py new file mode 100644 index 000000000..9cb7fcd85 --- /dev/null +++ b/ppdiffusers/ppdiffusers/schedulers/scheduling_ddim_cogvideox.py @@ -0,0 +1,341 @@ +import paddle +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union +import numpy as np +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. 
+ + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + prev_sample: paddle.Tensor + pred_original_sample: Optional[paddle.Tensor] = None + + +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999, + alpha_transform_type='cosine'): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == 'cosine': + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + elif alpha_transform_type == 'exp': + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + else: + raise ValueError( + f'Unsupported alpha_transform_type: {alpha_transform_type}') + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return paddle.to_tensor(data=betas, dtype='float32') + + +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + alphas_bar_sqrt = alphas_cumprod.sqrt() + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + alphas_bar_sqrt -= alphas_bar_sqrt_T + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - + alphas_bar_sqrt_T) + alphas_bar = alphas_bar_sqrt ** 2 + return alphas_bar + + +class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin): + """ + `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. 
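To see what `rescale_zero_terminal_snr` above does in isolation: after rescaling, the last cumulative alpha is exactly zero, so the final training timestep corresponds to pure noise. A small numeric check, using the scaled-linear beta schedule parameters from the constructor further below:

```python
import paddle

betas = paddle.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000, dtype="float64") ** 2
alphas_cumprod = paddle.cumprod(1.0 - betas, dim=0)

alphas_bar_sqrt = alphas_cumprod.sqrt()
a_0, a_T = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
rescaled = ((alphas_bar_sqrt - a_T) * a_0 / (a_0 - a_T)) ** 2

print(float(alphas_cumprod[-1]))   # small but non-zero terminal alpha_bar
print(float(rescaled[-1]))         # exactly 0.0 -> zero terminal SNR
```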
+ trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 
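One CogVideoX-specific detail the constructor below adds on top of a standard DDIM schedule is the SNR shift `alpha_bar <- alpha_bar / (s + (1 - s) * alpha_bar)` with `s = snr_shift_scale = 3.0`. A quick numeric illustration of its effect on a few cumulative alphas:

```python
import paddle

s = 3.0
alphas_cumprod = paddle.to_tensor([0.999, 0.9, 0.5, 0.1])
shifted = alphas_cumprod / (s + (1 - s) * alphas_cumprod)
print(shifted.numpy())  # ~[0.997 0.75 0.25 0.036]; every alpha_bar is pulled down,
                        # i.e. each timestep is noisier than in the image schedule
```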
+ """ + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__(self, num_train_timesteps: int=1000, beta_start: float= + 0.00085, beta_end: float=0.012, beta_schedule: str='scaled_linear', + trained_betas: Optional[Union[np.ndarray, List[float]]]=None, + clip_sample: bool=True, set_alpha_to_one: bool=True, steps_offset: + int=0, prediction_type: str='epsilon', clip_sample_range: float=1.0, + sample_max_value: float=1.0, timestep_spacing: str='leading', + rescale_betas_zero_snr: bool=False, snr_shift_scale: float=3.0): + if trained_betas is not None: + self.betas = paddle.to_tensor(data=trained_betas, dtype='float32') + elif beta_schedule == 'linear': + self.betas = paddle.linspace(start=beta_start, stop=beta_end, + num=num_train_timesteps, dtype='float32') + elif beta_schedule == 'scaled_linear': + self.betas = paddle.linspace(start=beta_start ** 0.5, stop= + beta_end ** 0.5, num=num_train_timesteps, dtype='float64') ** 2 + elif beta_schedule == 'squaredcos_cap_v2': + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError( + f'{beta_schedule} is not implemented for {self.__class__}') + self.alphas = 1.0 - self.betas + self.alphas_cumprod = paddle.cumprod(x=self.alphas, dim=0) + self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - + snr_shift_scale) * self.alphas_cumprod) + if rescale_betas_zero_snr: + self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod + ) + self.final_alpha_cumprod = paddle.to_tensor(data=1.0 + ) if set_alpha_to_one else self.alphas_cumprod[0] + self.init_noise_sigma = 1.0 + self.num_inference_steps = None + self.timesteps = paddle.to_tensor(data=np.arange(0, + num_train_timesteps)[::-1].copy().astype(np.int64)) + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep + ] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + variance = beta_prod_t_prev / beta_prod_t * (1 - alpha_prod_t / + alpha_prod_t_prev) + return variance + + def scale_model_input(self, sample: paddle.Tensor, timestep: Optional[ + int]=None) ->paddle.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def set_timesteps(self, num_inference_steps: int): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + """ + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f'`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`: {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle maximal {self.config.num_train_timesteps} timesteps.' 
+ ) + self.num_inference_steps = num_inference_steps + if self.config.timestep_spacing == 'linspace': + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, + num_inference_steps).round()[::-1].copy().astype(np.int64) + elif self.config.timestep_spacing == 'leading': + step_ratio = (self.config.num_train_timesteps // self. + num_inference_steps) + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round( + )[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == 'trailing': + step_ratio = (self.config.num_train_timesteps / self. + num_inference_steps) + timesteps = np.round(np.arange(self.config.num_train_timesteps, + 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + self.timesteps = paddle.to_tensor(data=timesteps) + + def step(self, model_output: paddle.Tensor, timestep: int, sample: + paddle.Tensor, eta: float=0.0, use_clipped_model_output: bool=False, + generator=None, variance_noise: Optional[paddle.Tensor]=None, + return_dict: bool=True) ->Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
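+
+        Example (an illustrative denoising-loop sketch; `transformer`, `latents` and `prompt_embeds` are
+        placeholders provided by a pipeline, not by this scheduler):
+
+        ```python
+        >>> scheduler.set_timesteps(50)
+        >>> for t in scheduler.timesteps:
+        ...     noise_pred = transformer(latents, prompt_embeds, t)  # predicted noise (epsilon)
+        ...     latents = scheduler.step(noise_pred, t, latents).prev_sample
+        ```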
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + prev_timestep = (timestep - self.config.num_train_timesteps // self + .num_inference_steps) + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep + ] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + if self.config.prediction_type == 'epsilon': + pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output + ) / alpha_prod_t ** 0.5 + elif self.config.prediction_type == 'sample': + pred_original_sample = model_output + elif self.config.prediction_type == 'v_prediction': + pred_original_sample = (alpha_prod_t ** 0.5 * sample - + beta_prod_t ** 0.5 * model_output) + else: + raise ValueError( + f'prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`' + ) + a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 + b_t = alpha_prod_t_prev ** 0.5 - alpha_prod_t ** 0.5 * a_t + prev_sample = a_t * sample + b_t * pred_original_sample + if not return_dict: + return prev_sample, + return DDIMSchedulerOutput(prev_sample=prev_sample, + pred_original_sample=pred_original_sample) + + def add_noise(self, original_samples: paddle.Tensor, noise: paddle. + Tensor, timesteps: paddle.Tensor) ->paddle.Tensor: + alphas_cumprod = self.alphas_cumprod.to(dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.place) + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(tuple(sqrt_alpha_prod.shape)) < len(tuple( + original_samples.shape)): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(axis=-1) + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(tuple(sqrt_one_minus_alpha_prod.shape)) < len(tuple( + original_samples.shape)): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze( + axis=-1) + noisy_samples = (sqrt_alpha_prod * original_samples + + sqrt_one_minus_alpha_prod * noise) + return noisy_samples + + def get_velocity(self, sample: paddle.Tensor, noise: paddle.Tensor, + timesteps: paddle.Tensor) ->paddle.Tensor: + alphas_cumprod = self.alphas_cumprod.to(dtype=sample.dtype) + timesteps = timesteps.to(sample.place) + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(tuple(sqrt_alpha_prod.shape)) < len(tuple(sample.shape)): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(axis=-1) + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(tuple(sqrt_one_minus_alpha_prod.shape)) < len(tuple( + sample.shape)): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze( + axis=-1) + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps diff --git a/ppdiffusers/ppdiffusers/schedulers/scheduling_dpm_cogvideox.py b/ppdiffusers/ppdiffusers/schedulers/scheduling_dpm_cogvideox.py new file mode 100644 index 000000000..9e679db98 --- /dev/null +++ b/ppdiffusers/ppdiffusers/schedulers/scheduling_dpm_cogvideox.py @@ -0,0 +1,453 @@ +import paddle +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union +import numpy as np +from 
..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from ..utils.paddle_utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + prev_sample: paddle.Tensor + pred_original_sample: Optional[paddle.Tensor] = None + + +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type='cosine' +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == 'cosine': + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == 'exp': + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError( + f'Unsupported alpha_transform_type: {alpha_transform_type}') + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return paddle.to_tensor(betas, dtype='float32') + + +def rescale_zero_terminal_snr(alphas_cumprod): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + + return alphas_bar + + +class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin): + """ + `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. 
Check the superclass documentation for the generic
+    methods the library implements for all schedulers such as loading and saving.
+
+    Args:
+        num_train_timesteps (`int`, defaults to 1000):
+            The number of diffusion steps to train the model.
+        beta_start (`float`, defaults to 0.0001):
+            The starting `beta` value of inference.
+        beta_end (`float`, defaults to 0.02):
+            The final `beta` value.
+        beta_schedule (`str`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        trained_betas (`np.ndarray`, *optional*):
+            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+        clip_sample (`bool`, defaults to `True`):
+            Clip the predicted sample for numerical stability.
+        clip_sample_range (`float`, defaults to 1.0):
+            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+        set_alpha_to_one (`bool`, defaults to `True`):
+            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+            otherwise it uses the alpha value at step 0.
+        steps_offset (`int`, defaults to 0):
+            An offset added to the inference steps, as required by some model families.
+        prediction_type (`str`, defaults to `epsilon`, *optional*):
+            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
+            Video](https://imagen.research.google/video/paper.pdf) paper).
+        thresholding (`bool`, defaults to `False`):
+            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+            as Stable Diffusion.
+        dynamic_thresholding_ratio (`float`, defaults to 0.995):
+            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+        sample_max_value (`float`, defaults to 1.0):
+            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+        timestep_spacing (`str`, defaults to `"leading"`):
+            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+        rescale_betas_zero_snr (`bool`, defaults to `False`):
+            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+            dark samples instead of limiting it to samples with medium brightness. Loosely related to
+            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+        snr_shift_scale (`float`, defaults to 3.0):
+            Shift applied to the signal-to-noise ratio of the cumulative alpha products, computed as
+            `alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * alphas_cumprod)`.
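+
+    Example (a minimal sketch; it assumes a CogVideoX checkpoint such as `THUDM/CogVideoX-5b` and shows the DPM
+    scheduler swap recommended for the 5B models):
+
+    ```python
+    >>> import paddle
+    >>> from ppdiffusers import CogVideoXDPMScheduler, CogVideoXPipeline
+
+    >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", paddle_dtype=paddle.bfloat16)
+    >>> pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+
+    >>> # Inside a denoising loop, `step` carries the previous x0 prediction and timestep along
+    >>> # (illustrative fragment; `noise_pred`, `old_pred`, `latents`, `timesteps` and `i` come from the pipeline):
+    >>> # latents, old_pred = pipe.scheduler.step(
+    >>> #     noise_pred, old_pred, t, timesteps[i - 1] if i > 0 else None, latents, return_dict=False
+    >>> # )
+    ```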
+ """ + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int=1000, + beta_start: float=0.00085, + beta_end: float=0.012, + beta_schedule: str='scaled_linear', + trained_betas: Optional[Union[np.ndarray, List[float]]]=None, + clip_sample: bool=True, + set_alpha_to_one: bool=True, + steps_offset: int=0, + prediction_type: str='epsilon', + clip_sample_range: float=1.0, + sample_max_value: float=1.0, + timestep_spacing: str='leading', + rescale_betas_zero_snr: bool=False, + snr_shift_scale: float=3.0 + ): + if trained_betas is not None: + self.betas = paddle.to_tensor(trained_betas, dtype='float32') + elif beta_schedule == 'linear': + self.betas = paddle.linspace(beta_start, beta_end, num_train_timesteps, dtype='float32') + elif beta_schedule == 'scaled_linear': + self.betas = paddle.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype='float64') ** 2 + elif beta_schedule == 'squaredcos_cap_v2': + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError( + f'{beta_schedule} is not implemented for {self.__class__}') + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = paddle.cumprod(x=self.alphas, dim=0) + + # Modify: SNR shift following SD3 + self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod) + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = paddle.to_tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = paddle.to_tensor(np.arange(0,num_train_timesteps)[::-1].copy().astype(np.int64)) + + def _get_variance( + self, + timestep, + prev_timestep + ): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def scale_model_input( + self, + sample: paddle.Tensor, + timestep: Optional[int]=None + ) ->paddle.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + def set_timesteps(self, num_inference_steps): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. 
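+
+        Example (illustrative values; assumes the default 1000 training timesteps and
+        `timestep_spacing="trailing"`):
+
+        ```python
+        >>> scheduler.set_timesteps(4)
+        >>> scheduler.timesteps.tolist()
+        [999, 749, 499, 249]
+        ```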
+ """ + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f'`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`: {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle maximal {self.config.num_train_timesteps} timesteps.' + ) + self.num_inference_steps = num_inference_steps + + if self.config.timestep_spacing == 'linspace': + timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round()[::-1].copy().astype(np.int64) + + elif self.config.timestep_spacing == 'leading': + step_ratio = (self.config.num_train_timesteps // self. num_inference_steps) + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + + elif self.config.timestep_spacing == 'trailing': + step_ratio = (self.config.num_train_timesteps / self.num_inference_steps) + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + self.timesteps = paddle.to_tensor(timesteps) + + def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None): + lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log() + lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log() + h = lamb_next - lamb + + if alpha_prod_t_back is not None: + lamb_previous = ((alpha_prod_t_back / (1 - alpha_prod_t_back)) ** 0.5).log() + h_last = lamb - lamb_previous + r = h_last / h + return h, r, lamb, lamb_next + else: + return h, None, lamb, lamb_next + + def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back): + mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp() + mult2 = (-2 * h).expm1() * alpha_prod_t_prev ** 0.5 + if alpha_prod_t_back is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def step( + self, + model_output: paddle.Tensor, + old_pred_original_sample: paddle.Tensor, + timestep: int, + timestep_back: int, + sample: paddle. + Tensor, eta: float=0.0, + use_clipped_model_output: bool=False, + generator=None, + variance_noise: Optional[paddle.Tensor]=None, + return_dict: bool=False, + ) ->Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. 
+ variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + alpha_prod_t_back = self.alphas_cumprod[timestep_back] if timestep_back is not None else None + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # To make style tests pass, commented out `pred_epsilon` as it is an unused variable + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + # pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + h, r, lamb, lamb_next = self.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back) + mult = list(self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)) + mult_noise = (1 - alpha_prod_t_prev) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5 + + noise = randn_tensor(sample.shape, generator=generator, dtype=sample.dtype) + prev_sample = mult[0] * sample - mult[1] * pred_original_sample + mult_noise * noise + + if old_pred_original_sample is None or prev_timestep < 0: + # Save a network evaluation if all noise levels are 0 or on the first step + return prev_sample, pred_original_sample + else: + denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample + noise = randn_tensor(sample.shape, generator=generator, dtype=sample.dtype) + x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise + + prev_sample = x_advanced + + if not return_dict: + return (prev_sample, pred_original_sample) + + return 
DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + + def add_noise( + self, + original_samples: paddle.Tensor, + noise: paddle.Tensor, + timesteps: paddle.Tensor + ) ->paddle.Tensor: + + alphas_cumprod = self.alphas_cumprod.cast(original_samples.dtype) + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(axis=-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(axis=-1) + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + + return noisy_samples + + def get_velocity( + self, + sample: paddle.Tensor, + noise: paddle.Tensor, + timesteps: paddle.Tensor + ) ->paddle.Tensor: + alphas_cumprod = self.alphas_cumprod.cast(sample.dtype) + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(axis=-1) + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(axis=-1) + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + + return velocity + + def __len__(self): + return self.config.num_train_timesteps diff --git a/ppdiffusers/ppdiffusers/utils/export_utils.py b/ppdiffusers/ppdiffusers/utils/export_utils.py index 130ec3c05..df94ec2e2 100644 --- a/ppdiffusers/ppdiffusers/utils/export_utils.py +++ b/ppdiffusers/ppdiffusers/utils/export_utils.py @@ -161,15 +161,35 @@ def export_to_video( video_writer.write(img) return output_video_path - def export_to_video_2( - video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8 -): + video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10 +) -> str: try: import imageio - except ImportError: - raise ImportError("Please install imageio to export video.run `pip install imageio`") + except: + raise ImportError("imageio is not found") + + try: + imageio.plugins.ffmpeg.get_exe() + except AttributeError: + raise AttributeError( + ( + "Found an existing imageio backend in your environment. Attempting to export video with imageio. \n" + "Unable to find a compatible ffmpeg installation in your environment to use with imageio. 
Please install via `pip install imageio-ffmpeg`"
+            )
+        )
+
     if output_video_path is None:
         output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
-    imageio.mimsave(output_video_path, video_frames, fps=fps, codec="mpeg4")
+    if isinstance(video_frames[0], np.ndarray):
+        video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
+
+    elif isinstance(video_frames[0], PIL.Image.Image):
+        video_frames = [np.array(frame) for frame in video_frames]
+
+    with imageio.get_writer(output_video_path, fps=fps) as writer:
+        for frame in video_frames:
+            writer.append_data(frame)
+
+    return output_video_path
diff --git a/ppdiffusers/ppdiffusers/video_processor.py b/ppdiffusers/ppdiffusers/video_processor.py
new file mode 100644
index 000000000..f8c9e5d2a
--- /dev/null
+++ b/ppdiffusers/ppdiffusers/video_processor.py
@@ -0,0 +1,89 @@
+import warnings
+from typing import List, Optional, Union
+
+import numpy as np
+import paddle
+import PIL
+
+from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
+
+
+class VideoProcessor(VaeImageProcessor):
+    """Simple video processor."""
+
+    def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> paddle.Tensor:
+        """
+        Preprocesses input video(s).
+
+        Args:
+            video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `paddle.Tensor`, `np.array`, `List[paddle.Tensor]`, `List[np.array]`):
+                The input video. It can be one of the following:
+                * List of PIL images.
+                * List of lists of PIL images.
+                * 4D Paddle tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`).
+                * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`).
+                * List of 4D Paddle tensors (expected shape for each tensor `(num_frames, num_channels, height,
+                  width)`).
+                * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`).
+                * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width,
+                  num_channels)`.
+                * 5D Paddle tensors: expected shape for each array `(batch_size, num_frames, num_channels, height,
+                  width)`.
+            height (`int`, *optional*, defaults to `None`):
+                The height of the preprocessed frames. If `None`, `get_default_height_width()` is used to get the
+                default height.
+            width (`int`, *optional*, defaults to `None`):
+                The width of the preprocessed frames. If `None`, `get_default_height_width()` is used to get the
+                default width.
+        """
+        if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5:
+            warnings.warn(
+                "Passing `video` as a list of 5d np.ndarray is deprecated. "
+                "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray",
+                FutureWarning,
+            )
+            video = np.concatenate(video, axis=0)
+        if isinstance(video, list) and isinstance(video[0], paddle.Tensor) and video[0].ndim == 5:
+            warnings.warn(
+                "Passing `video` as a list of 5d paddle.Tensor is deprecated. "
+                "Please concatenate the list along the batch dimension and pass it as a single 5d paddle.Tensor",
+                FutureWarning,
+            )
+            video = paddle.concat(video, axis=0)
+        if isinstance(video, (np.ndarray, paddle.Tensor)) and video.ndim == 5:
+            video = list(video)
+        elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video):
+            video = [video]
+        elif isinstance(video, list) and is_valid_image_imagelist(video[0]):
+            video = video
+        else:
+            raise ValueError(
+                "Input is in incorrect format. Currently, we only support numpy.ndarray, paddle.Tensor, PIL.Image.Image"
+            )
+        video = paddle.stack(
+            x=[self.preprocess(img, height=height, width=width) for img in video], axis=0
+        )
+        video = video.transpose(perm=[0, 2, 1, 3, 4])
+        return video
+
+    def postprocess_video(
+        self, video: paddle.Tensor, output_type: str = "np"
+    ) -> Union[np.ndarray, paddle.Tensor, List[PIL.Image.Image]]:
+        """
+        Converts a video tensor to a list of frames for export.
+
+        Args:
+            video (`paddle.Tensor`): The video as a tensor.
+            output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor.
+        """
+        batch_size = tuple(video.shape)[0]
+        outputs = []
+        for batch_idx in range(batch_size):
+            batch_vid = video[batch_idx].transpose(perm=[1, 0, 2, 3])
+            batch_output = self.postprocess(batch_vid, output_type)
+            outputs.append(batch_output)
+        if output_type == "np":
+            outputs = np.stack(outputs)
+        elif output_type == "pd":
+            outputs = paddle.stack(x=outputs)
+        elif output_type != "pil":
+            raise ValueError(
+                f"{output_type} does not exist. Please choose one of ['np', 'pd', 'pil']"
+            )
+        return outputs
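+
+
+# Usage sketch (illustrative only): `postprocess_video` is typically paired with a frame exporter such as
+# `export_to_video_2`. The decoded tensor `video` of shape (batch_size, channels, num_frames, height, width)
+# is an assumption here, not something this module produces.
+#
+#     from ppdiffusers.utils import export_to_video_2
+#
+#     video_processor = VideoProcessor(vae_scale_factor=8)
+#     frames = video_processor.postprocess_video(video, output_type="pil")[0]
+#     export_to_video_2(frames, "output.mp4", fps=8)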