diff --git a/ppdiffusers/examples/cogvideo/README.md b/ppdiffusers/examples/cogvideo/README.md
new file mode 100644
index 000000000..8e13431d6
--- /dev/null
+++ b/ppdiffusers/examples/cogvideo/README.md
@@ -0,0 +1,10 @@
+# CogVideoX Video Generation
+
+```shell
+python infer.py \
+ --prompt "a bear is walking in a zoo" \
+ --model_path THUDM/CogVideoX-2b \
+ --generate_type "t2v" \
+ --dtype "float16" \
+ --seed 42
+```
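+
+The pipeline can also be called directly from Python. The snippet below is a minimal sketch of the same
+text-to-video run (it assumes the `THUDM/CogVideoX-2b` weights can be loaded via `from_pretrained`):
+
+```python
+import paddle
+
+from ppdiffusers import CogVideoXPipeline
+from ppdiffusers.utils import export_to_video_2
+
+# Load the text-to-video pipeline in float16 to reduce memory usage.
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", paddle_dtype=paddle.float16)
+
+video = pipe(
+    prompt="a bear is walking in a zoo",
+    num_inference_steps=50,
+    guidance_scale=6.0,
+    num_frames=49,  # 6 seconds at 8 fps, plus one initial frame
+    generator=paddle.Generator().manual_seed(42),
+).frames[0]
+
+# CogVideoX generates video at 8 fps.
+export_to_video_2(video, "output.mp4", fps=8)
+```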
\ No newline at end of file
diff --git a/ppdiffusers/examples/cogvideo/infer.py b/ppdiffusers/examples/cogvideo/infer.py
new file mode 100644
index 000000000..8f5a0bc1c
--- /dev/null
+++ b/ppdiffusers/examples/cogvideo/infer.py
@@ -0,0 +1,186 @@
+"""
+This script demonstrates how to generate a video with the CogVideoX model using the PaddlePaddle `ppdiffusers` pipeline.
+The script is structured to support text-to-video (t2v), image-to-video (i2v), and video-to-video (v2v) generation,
+but only t2v is currently implemented; the i2v and v2v branches raise `NotImplementedError`.
+
+- text-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
+- video-to-video: THUDM/CogVideoX-5b or THUDM/CogVideoX-2b
+- image-to-video: THUDM/CogVideoX-5b-I2V
+
+Running the Script:
+To run the script, use the following command with appropriate arguments:
+
+```bash
+$ python infer.py --prompt "A girl riding a bike." --model_path THUDM/CogVideoX-5b --generate_type "t2v"
+```
+
+Additional options are available to specify the model path, guidance scale, number of inference steps, video generation type, and output paths.
+"""
+
+import argparse
+from typing import Literal
+
+import paddle
+from ppdiffusers import (
+ CogVideoXPipeline,
+ CogVideoXDDIMScheduler,
+ CogVideoXDPMScheduler,
+ # CogVideoXImageToVideoPipeline,
+ # CogVideoXVideoToVideoPipeline,
+)
+
+from ppdiffusers.utils import export_to_video_2
+
+
+def generate_video(
+ prompt: str,
+ model_path: str,
+ lora_path: str = None,
+ lora_rank: int = 128,
+ output_path: str = "./output.mp4",
+ image_or_video_path: str = "",
+ num_inference_steps: int = 50,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: int = 1,
+ dtype: paddle.dtype = paddle.bfloat16,
+ generate_type: Literal["t2v", "i2v", "v2v"] = "t2v",  # i2v: image to video, v2v: video to video
+ seed: int = 42,
+):
+ """
+ Generates a video based on the given prompt and saves it to the specified path.
+
+ Parameters:
+ - prompt (str): The description of the video to be generated.
+ - model_path (str): The path of the pre-trained model to be used.
+ - lora_path (str): The path of the LoRA weights to be used.
+ - lora_rank (int): The rank of the LoRA weights.
+ - output_path (str): The path where the generated video will be saved.
+ - num_inference_steps (int): Number of steps for the inference process. More steps can result in better quality.
+ - guidance_scale (float): The scale for classifier-free guidance. Higher values can lead to better alignment with the prompt.
+ - num_videos_per_prompt (int): Number of videos to generate per prompt.
+ - dtype (paddle.dtype): The data type for computation (default is paddle.bfloat16).
+ - generate_type (str): The type of video generation (e.g., 't2v', 'i2v', 'v2v').
+ - seed (int): The seed for reproducibility.
+ """
+
+ # 1. Load the pre-trained CogVideoX pipeline with the specified precision.
+
+ image = None
+ video = None
+
+ if generate_type == "i2v":
+ # pipe = CogVideoXImageToVideoPipeline.from_pretrained(model_path, dtype=dtype)
+ # image = load_image(image=image_or_video_path)
+ raise NotImplementedError
+ elif generate_type == "t2v":
+ pipe = CogVideoXPipeline.from_pretrained(model_path, paddle_dtype=dtype)
+ else:
+ # pipe = CogVideoXVideoToVideoPipeline.from_pretrained(model_path, dtype=dtype)
+ # video = load_video(image_or_video_path)
+ raise NotImplementedError
+
+ # Optionally load and fuse LoRA weights.
+ if lora_path:
+ pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name="test_1")
+ pipe.fuse_lora(lora_scale=1 / lora_rank)
+
+ # 2. Set the scheduler.
+ # Either `CogVideoXDDIMScheduler` or `CogVideoXDPMScheduler` can be used:
+ # `CogVideoXDDIMScheduler` is recommended for CogVideoX-2B and
+ # `CogVideoXDPMScheduler` for CogVideoX-5B / CogVideoX-5B-I2V.
+
+ pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+ # pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
+
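+ # 3. Enable VAE slicing and tiling to reduce peak memory usage during decoding.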
+ pipe.vae.enable_slicing()
+ pipe.vae.enable_tiling()
+
+ # 4. Generate the video frames based on the prompt.
+ # `num_frames` is the number of frames to generate.
+ # The default of 49 corresponds to a 6-second video at 8 fps, plus 1 extra frame for the first frame.
+ if generate_type == "i2v":
+ # video_generate = pipe(
+ # prompt=prompt,
+ # image=image, # The path of the image to be used as the background of the video
+ # num_videos_per_prompt=num_videos_per_prompt, # Number of videos to generate per prompt
+ # num_inference_steps=num_inference_steps, # Number of inference steps
+ # num_frames=49, # Number of frames to generate,changed to 49 for diffusers version `0.30.3` and after.
+ # use_dynamic_cfg=True, # Used with the DPM scheduler; for the DDIM scheduler, set this to False
+ # guidance_scale=guidance_scale,
+ # generator=paddle.seed(seed), # Set the seed for reproducibility
+ # ).frames[0]
+ raise NotImplementedError
+ elif generate_type == "t2v":
+ video_generate = pipe(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ num_inference_steps=num_inference_steps,
+ num_frames=49,
+ use_dynamic_cfg=True,
+ guidance_scale=guidance_scale,
+ generator=paddle.Generator().manual_seed(seed),
+ ).frames[0]
+ else:
+ # video_generate = pipe(
+ # prompt=prompt,
+ # video=video, # The path of the video to be used as the background of the video
+ # num_videos_per_prompt=num_videos_per_prompt,
+ # num_inference_steps=num_inference_steps,
+ # # num_frames=49,
+ # use_dynamic_cfg=True,
+ # guidance_scale=guidance_scale,
+ # generator=paddle.seed(seed), # Set the seed for reproducibility
+ # ).frames[0]
+ raise NotImplementedError
+ # 5. Export the generated frames to a video file. Use fps=8 to match CogVideoX's native frame rate.
+ export_to_video_2(video_generate, output_path, fps=8)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Generate a video from a text prompt using CogVideoX")
+ parser.add_argument("--prompt", type=str, required=True, help="The description of the video to be generated")
+ parser.add_argument(
+ "--image_or_video_path",
+ type=str,
+ default=None,
+ help="The path of the input image (for i2v) or input video (for v2v)",
+ )
+ parser.add_argument(
+ "--model_path", type=str, default="THUDM/CogVideoX-5b", help="The path of the pre-trained model to be used"
+ )
+ parser.add_argument("--lora_path", type=str, default=None, help="The path of the LoRA weights to be used")
+ parser.add_argument("--lora_rank", type=int, default=128, help="The rank of the LoRA weights")
+ parser.add_argument(
+ "--output_path", type=str, default="./output.mp4", help="The path where the generated video will be saved"
+ )
+ parser.add_argument("--guidance_scale", type=float, default=6.0, help="The scale for classifier-free guidance")
+ parser.add_argument(
+ "--num_inference_steps", type=int, default=50, help="Number of steps for the inference process"
+ )
+ parser.add_argument("--num_videos_per_prompt", type=int, default=1, help="Number of videos to generate per prompt")
+ parser.add_argument(
+ "--generate_type", type=str, default="t2v", help="The type of video generation (e.g., 't2v', 'i2v', 'v2v')"
+ )
+ parser.add_argument(
+ "--dtype", type=str, default="bfloat16", help="The data type for computation (e.g., 'float16' or 'bfloat16')"
+ )
+ parser.add_argument("--seed", type=int, default=42, help="The seed for reproducibility")
+
+ args = parser.parse_args()
+ dtype = paddle.float16 if args.dtype == "float16" else paddle.bfloat16
+ generate_video(
+ prompt=args.prompt,
+ model_path=args.model_path,
+ lora_path=args.lora_path,
+ lora_rank=args.lora_rank,
+ output_path=args.output_path,
+ image_or_video_path=args.image_or_video_path,
+ num_inference_steps=args.num_inference_steps,
+ guidance_scale=args.guidance_scale,
+ num_videos_per_prompt=args.num_videos_per_prompt,
+ dtype=dtype,
+ generate_type=args.generate_type,
+ seed=args.seed,
+ )
diff --git a/ppdiffusers/ppdiffusers/__init__.py b/ppdiffusers/ppdiffusers/__init__.py
index d47fc7a38..5d21f1f1c 100644
--- a/ppdiffusers/ppdiffusers/__init__.py
+++ b/ppdiffusers/ppdiffusers/__init__.py
@@ -110,8 +110,10 @@
[
"AsymmetricAutoencoderKL",
"AutoencoderKL",
+ "AutoencoderKLCogVideoX",
"AutoencoderKLTemporalDecoder",
"AutoencoderTiny",
+ "CogVideoXTransformer3DModel",
"ConsistencyDecoderVAE",
"ControlNetModel",
"Kandinsky3UNet",
@@ -182,6 +184,8 @@
_import_structure["schedulers"].extend(
[
"CMStochasticIterativeScheduler",
+ "CogVideoXDDIMScheduler",
+ "CogVideoXDPMScheduler",
"DDIMInverseScheduler",
"DDIMParallelScheduler",
"DDIMScheduler",
@@ -266,6 +270,7 @@
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"CLIPImageProjection",
+ "CogVideoXPipeline",
"CycleDiffusionPipeline",
"IFImg2ImgPipeline",
"IFImg2ImgSuperResolutionPipeline",
@@ -486,9 +491,11 @@
from .models import ( # new add
AsymmetricAutoencoderKL,
AutoencoderKL,
+ AutoencoderKLCogVideoX,
AutoencoderKL_imgtovideo,
AutoencoderKLTemporalDecoder,
AutoencoderTiny,
+ CogVideoXTransformer3DModel,
ConsistencyDecoderVAE,
ControlNetModel,
DiTLLaMA2DModel,
@@ -554,6 +561,8 @@
)
from .schedulers import (
CMStochasticIterativeScheduler,
+ CogVideoXDDIMScheduler,
+ CogVideoXDPMScheduler,
DDIMInverseScheduler,
DDIMParallelScheduler,
DDIMScheduler,
@@ -619,6 +628,7 @@
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
CLIPImageProjection,
+ CogVideoXPipeline,
CycleDiffusionPipeline,
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
diff --git a/ppdiffusers/ppdiffusers/callbacks.py b/ppdiffusers/ppdiffusers/callbacks.py
new file mode 100644
index 000000000..38542407e
--- /dev/null
+++ b/ppdiffusers/ppdiffusers/callbacks.py
@@ -0,0 +1,156 @@
+from typing import Any, Dict, List
+
+from .configuration_utils import ConfigMixin, register_to_config
+from .utils import CONFIG_NAME
+
+
+class PipelineCallback(ConfigMixin):
+ """
+ Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing
+ custom callbacks and ensures that all callbacks have a consistent interface.
+
+ Please implement the following:
+ `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to
+ include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
+ `callback_fn`: This method defines the core functionality of your callback.
+ """
+
+ config_name = CONFIG_NAME
+
+ @register_to_config
+ def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
+ super().__init__()
+
+ if (cutoff_step_ratio is None and cutoff_step_index is None) or (
+ cutoff_step_ratio is not None and cutoff_step_index is not None
+ ):
+ raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")
+
+ if cutoff_step_ratio is not None and (
+ not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0)
+ ):
+ raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")
+
+ @property
+ def tensor_inputs(self) -> List[str]:
+ raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")
+
+ def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
+ raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")
+
+ def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+ return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)
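+
+
+# A minimal sketch of a custom callback following the contract described in `PipelineCallback`
+# (illustrative only; the class name is hypothetical, and "latents" must be listed in the
+# pipeline's `_callback_tensor_inputs` for this to work):
+#
+#     class LatentInspectionCallback(PipelineCallback):
+#         tensor_inputs = ["latents"]
+#
+#         def callback_fn(self, pipeline, step_index, timestep, callback_kwargs):
+#             # read or modify callback_kwargs["latents"] here, then return the dict
+#             return callback_kwargs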
+
+
+class MultiPipelineCallbacks:
+ """
+ This class is designed to handle multiple pipeline callbacks. It accepts a list of PipelineCallback objects and
+ provides a unified interface for calling all of them.
+ """
+
+ def __init__(self, callbacks: List[PipelineCallback]):
+ self.callbacks = callbacks
+
+ @property
+ def tensor_inputs(self) -> List[str]:
+ return [input for callback in self.callbacks for input in callback.tensor_inputs]
+
+ def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+ """
+ Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
+ """
+ for callback in self.callbacks:
+ callback_kwargs = callback(pipeline, step_index, timestep, callback_kwargs)
+
+ return callback_kwargs
+
+
+class SDCFGCutoffCallback(PipelineCallback):
+ """
+ Callback function for Stable Diffusion Pipelines. After a certain number of steps (set by `cutoff_step_ratio` or
+ `cutoff_step_index`), this callback will disable the CFG.
+
+ Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
+ """
+
+ tensor_inputs = ["prompt_embeds"]
+
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+ cutoff_step_ratio = self.config.cutoff_step_ratio
+ cutoff_step_index = self.config.cutoff_step_index
+
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+ cutoff_step = (
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+ )
+
+ if step_index == cutoff_step:
+ prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
+ prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
+
+ pipeline._guidance_scale = 0.0
+
+ callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+ return callback_kwargs
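+
+
+# Usage sketch: diffusers-style pipelines accept these callbacks through the `callback_on_step_end`
+# argument of their `__call__`; whether a given ppdiffusers pipeline already exposes that argument
+# is an assumption here, not something this module guarantees.
+#
+#     callback = SDCFGCutoffCallback(cutoff_step_ratio=0.4)
+#     image = pipe(
+#         prompt,
+#         callback_on_step_end=callback,
+#         callback_on_step_end_tensor_inputs=callback.tensor_inputs,
+#     ).images[0]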
+
+
+class SDXLCFGCutoffCallback(PipelineCallback):
+ """
+ Callback function for Stable Diffusion XL Pipelines. After a certain number of steps (set by `cutoff_step_ratio` or
+ `cutoff_step_index`), this callback will disable the CFG.
+
+ Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
+ """
+
+ tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
+
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+ cutoff_step_ratio = self.config.cutoff_step_ratio
+ cutoff_step_index = self.config.cutoff_step_index
+
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+ cutoff_step = (
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+ )
+
+ if step_index == cutoff_step:
+ prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
+ prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
+
+ add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
+ add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens
+
+ add_time_ids = callback_kwargs[self.tensor_inputs[2]]
+ add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector
+
+ pipeline._guidance_scale = 0.0
+
+ callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+ callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
+ callback_kwargs[self.tensor_inputs[2]] = add_time_ids
+ return callback_kwargs
+
+
+class IPAdapterScaleCutoffCallback(PipelineCallback):
+ """
+ Callback function for any pipeline that inherits `IPAdapterMixin`. After a certain number of steps (set by
+ `cutoff_step_ratio` or `cutoff_step_index`), this callback will set the IP Adapter scale to `0.0`.
+
+ Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step.
+ """
+
+ tensor_inputs = []
+
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+ cutoff_step_ratio = self.config.cutoff_step_ratio
+ cutoff_step_index = self.config.cutoff_step_index
+
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+ cutoff_step = (
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+ )
+
+ if step_index == cutoff_step:
+ pipeline.set_ip_adapter_scale(0.0)
+ return callback_kwargs
diff --git a/ppdiffusers/ppdiffusers/image_processor.py b/ppdiffusers/ppdiffusers/image_processor.py
index 75ac521e8..41b8c3daa 100644
--- a/ppdiffusers/ppdiffusers/image_processor.py
+++ b/ppdiffusers/ppdiffusers/image_processor.py
@@ -650,3 +650,22 @@ def preprocess(
depth = self.binarize(depth)
return rgb, depth
+
+
+def is_valid_image(image):
+ return isinstance(image, PIL.Image.Image) or (isinstance(image, (np.ndarray, paddle.Tensor)) and image.ndim in (2, 3))
+
+
+def is_valid_image_imagelist(images):
+ # check if the image input is one of the supported formats for image and image list:
+ # it can be either one of below 3
+ # (1) a 4d paddle.Tensor or numpy array,
+ # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or paddle.Tensor (grayscale image), 3-d np.ndarray or paddle.Tensor
+ # (3) a list of valid images
+ if isinstance(images, (np.ndarray, paddle.Tensor)) and images.ndim == 4:
+ return True
+ elif is_valid_image(images):
+ return True
+ elif isinstance(images, list):
+ return all(is_valid_image(image) for image in images)
+ return False
\ No newline at end of file
diff --git a/ppdiffusers/ppdiffusers/models/__init__.py b/ppdiffusers/ppdiffusers/models/__init__.py
index 81bea4dcb..ed5187022 100644
--- a/ppdiffusers/ppdiffusers/models/__init__.py
+++ b/ppdiffusers/ppdiffusers/models/__init__.py
@@ -22,6 +22,7 @@
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
+ _import_structure["autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
@@ -32,6 +33,7 @@
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformer_sd3"] = ["SD3Transformer2DModel"]
+ _import_structure["cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"]
_import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unet_1d"] = ["UNet1DModel"]
_import_structure["unet_2d"] = ["UNet2DModel"]
@@ -64,6 +66,7 @@
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
+ from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
@@ -87,6 +90,7 @@
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
+ from .cogvideox_transformer_3d import CogVideoXTransformer3DModel
from .controlnet_sd3 import SD3ControlNetModel
from .controlnet_sd3 import SD3MultiControlNetModel
from .transformer_temporal import TransformerTemporalModel
diff --git a/ppdiffusers/ppdiffusers/models/activations.py b/ppdiffusers/ppdiffusers/models/activations.py
index 44db4a55d..42d56d455 100644
--- a/ppdiffusers/ppdiffusers/models/activations.py
+++ b/ppdiffusers/ppdiffusers/models/activations.py
@@ -66,9 +66,9 @@ class GELU(nn.Layer):
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
"""
- def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool=True):
super().__init__()
- self.proj = nn.Linear(dim_in, dim_out)
+ self.proj = nn.Linear(dim_in, dim_out, bias_attr=bias)
self.approximate = approximate
def gelu(self, gate: paddle.Tensor) -> paddle.Tensor:
@@ -89,11 +89,11 @@ class GEGLU(nn.Layer):
dim_out (`int`): The number of channels in the output.
"""
- def __init__(self, dim_in: int, dim_out: int):
+ def __init__(self, dim_in: int, dim_out: int, bias: bool=True):
super().__init__()
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
- self.proj = linear_cls(dim_in, dim_out * 2)
+ self.proj = linear_cls(dim_in, dim_out * 2, bias_attr=bias)
def gelu(self, gate: paddle.Tensor) -> paddle.Tensor:
return F.gelu(gate)
@@ -114,9 +114,9 @@ class ApproximateGELU(nn.Layer):
dim_out (`int`): The number of channels in the output.
"""
- def __init__(self, dim_in: int, dim_out: int):
+ def __init__(self, dim_in: int, dim_out: int, bias: bool=True):
super().__init__()
- self.proj = nn.Linear(dim_in, dim_out)
+ self.proj = nn.Linear(dim_in, dim_out, bias_attr=bias)
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
x = self.proj(x)
diff --git a/ppdiffusers/ppdiffusers/models/attention.py b/ppdiffusers/ppdiffusers/models/attention.py
index 8b5a9d027..ef3a810d0 100644
--- a/ppdiffusers/ppdiffusers/models/attention.py
+++ b/ppdiffusers/ppdiffusers/models/attention.py
@@ -641,20 +641,23 @@ def __init__(
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
+ inner_dim=None,
+ bias: bool = True,
):
super().__init__()
- inner_dim = int(dim * mult)
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
if activation_fn == "gelu":
- act_fn = GELU(dim, inner_dim)
+ act_fn = GELU(dim, inner_dim, bias=bias)
if activation_fn == "gelu-approximate":
- act_fn = GELU(dim, inner_dim, approximate="tanh")
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
elif activation_fn == "geglu":
- act_fn = GEGLU(dim, inner_dim)
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
elif activation_fn == "geglu-approximate":
- act_fn = ApproximateGELU(dim, inner_dim)
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
self.net = nn.LayerList([])
# project in
@@ -662,7 +665,7 @@ def __init__(
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
- self.net.append(linear_cls(inner_dim, dim_out))
+ self.net.append(linear_cls(inner_dim, dim_out, bias_attr=bias))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
diff --git a/ppdiffusers/ppdiffusers/models/attention_processor.py b/ppdiffusers/ppdiffusers/models/attention_processor.py
index c93c55ae6..3d6f1659b 100644
--- a/ppdiffusers/ppdiffusers/models/attention_processor.py
+++ b/ppdiffusers/ppdiffusers/models/attention_processor.py
@@ -91,6 +91,7 @@ def __init__(
upcast_softmax: bool = False,
cross_attention_norm: Optional[str] = None,
cross_attention_norm_num_groups: int = 32,
+ qk_norm: Optional[str] = None,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
spatial_norm_dim: Optional[int] = None,
@@ -149,6 +150,15 @@ def __init__(
else:
self.spatial_norm = None
+ if qk_norm is None:
+ self.norm_q = None
+ self.norm_k = None
+ elif qk_norm == "layer_norm":
+ self.norm_q = nn.LayerNorm(dim_head, epsilon=eps)
+ self.norm_k = nn.LayerNorm(dim_head, epsilon=eps)
+ else:
+ raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
+
if cross_attention_norm is None:
self.norm_cross = None
elif cross_attention_norm == "layer_norm":
@@ -2091,6 +2101,156 @@ def __call__(
return out
+class CogVideoXAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+ query and key vectors, but does not include spatial normalization.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: paddle.Tensor,
+ encoder_hidden_states: paddle.Tensor,
+ attention_mask: Optional[paddle.Tensor] = None,
+ image_rotary_emb: Optional[paddle.Tensor] = None,
+ ) -> paddle.Tensor:
+ text_seq_length = encoder_hidden_states.shape[1]
+
+ hidden_states = paddle.concat([encoder_hidden_states, hidden_states], axis=1)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.reshape([batch_size, attn.heads, -1, attention_mask.shape[-1]])
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3])
+ key = key.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3])
+ value = value.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3])
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+ if not attn.is_cross_attention:
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+ # NOTE: There is diff between paddle's and torch's sdpa
+ # paddle needs input: [batch_size, seq_len, num_heads, head_dim]
+ # torch needs input: [batch_size, num_heads, seq_len, head_dim]
+ hidden_states = F.scaled_dot_product_attention(
+ query.transpose([0, 2, 1, 3]),
+ key.transpose([0, 2, 1, 3]),
+ value.transpose([0, 2, 1, 3]),
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False
+ )
+
+ hidden_states = hidden_states.reshape([batch_size, -1, attn.heads * head_dim])
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, hidden_states.shape[1] - text_seq_length], axis=1
+ )
+ return hidden_states, encoder_hidden_states
+
+
+class FusedCogVideoXAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention for the CogVideoX model with fused query-key-value
+ projections (`attn.to_qkv`). It applies a rotary embedding on query and key vectors, but does not include
+ spatial normalization.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: paddle.Tensor,
+ encoder_hidden_states: paddle.Tensor,
+ attention_mask: Optional[paddle.Tensor] = None,
+ image_rotary_emb: Optional[paddle.Tensor] = None,
+ ) -> paddle.Tensor:
+ text_seq_length = encoder_hidden_states.shape[1]
+
+ hidden_states = paddle.concat([encoder_hidden_states, hidden_states], axis=1)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.reshape([batch_size, attn.heads, -1, attention_mask.shape[-1]])
+
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ # paddle.split takes section sizes (or a section count), not a chunk size like torch.split.
+ query, key, value = paddle.split(qkv, [split_size, split_size, split_size], axis=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3])
+ key = key.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3])
+ value = value.reshape([batch_size, -1, attn.heads, head_dim]).transpose([0, 2, 1, 3])
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+ if not attn.is_cross_attention:
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+ # NOTE: paddle's scaled_dot_product_attention expects [batch_size, seq_len, num_heads, head_dim],
+ # so transpose back before the call; its output can then be reshaped directly.
+ hidden_states = F.scaled_dot_product_attention(
+ query.transpose([0, 2, 1, 3]),
+ key.transpose([0, 2, 1, 3]),
+ value.transpose([0, 2, 1, 3]),
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ )
+
+ hidden_states = hidden_states.reshape([batch_size, -1, attn.heads * head_dim])
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, hidden_states.shape[1] - text_seq_length], axis=1
+ )
+ return hidden_states, encoder_hidden_states
+
+
LoRAAttnProcessor2_5 = LoRAXFormersAttnProcessor
AttnAddedKVProcessor2_5 = XFormersAttnAddedKVProcessor
AttnProcessor2_5 = XFormersAttnProcessor
diff --git a/ppdiffusers/ppdiffusers/models/autoencoder_kl_cogvideox.py b/ppdiffusers/ppdiffusers/models/autoencoder_kl_cogvideox.py
new file mode 100644
index 000000000..85a4a3acc
--- /dev/null
+++ b/ppdiffusers/ppdiffusers/models/autoencoder_kl_cogvideox.py
@@ -0,0 +1,1034 @@
+import paddle
+from typing import Optional, Tuple, Union
+import numpy as np
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
+from ..utils.accelerate_utils import apply_forward_hook
+from .activations import get_activation
+from .downsampling import CogVideoXDownsample3D
+from .modeling_outputs import AutoencoderKLOutput
+from .modeling_utils import ModelMixin
+from .upsampling import CogVideoXUpsample3D
+from .vae import DecoderOutput, DiagonalGaussianDistribution
+logger = logging.get_logger(__name__)
+
+
+class CogVideoXSafeConv3d(paddle.nn.Conv3D):
+ """
+ A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
+ """
+
+ def forward(self, input: paddle.Tensor) ->paddle.Tensor:
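+ # `memory_count` estimates the activation size in GiB, assuming 2 bytes per element (fp16/bf16).
+ # Inputs larger than ~2 GiB are convolved chunk-by-chunk along the temporal axis (with an overlap
+ # of `kernel_size - 1` frames between chunks) and re-concatenated to avoid OOM.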
+ memory_count = paddle.prod(x=paddle.to_tensor(data=tuple(input.shape))
+ ).item() * 2 / 1024 ** 3
+ if memory_count > 2:
+ kernel_size = self.kernel_size[0]
+ part_num = int(memory_count / 2) + 1
+ input_chunks = paddle.chunk(x=input, chunks=part_num, axis=2)
+ if kernel_size > 1:
+ input_chunks = [input_chunks[0]] + [paddle.concat(x=(
+ input_chunks[i - 1][:, :, -kernel_size + 1:],
+ input_chunks[i]), axis=2) for i in range(1, len(
+ input_chunks))]
+ output_chunks = []
+ for input_chunk in input_chunks:
+ output_chunks.append(super().forward(input_chunk))
+ output = paddle.concat(x=output_chunks, axis=2)
+ return output
+ else:
+ return super().forward(input)
+
+
+class CogVideoXCausalConv3d(paddle.nn.Layer):
+ """A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
+
+ Args:
+ in_channels (`int`): Number of channels in the input tensor.
+ out_channels (`int`): Number of output channels produced by the convolution.
+ kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
+ stride (`int`, defaults to `1`): Stride of the convolution.
+ dilation (`int`, defaults to `1`): Dilation rate of the convolution.
+ pad_mode (`str`, defaults to `"constant"`): Padding mode.
+ """
+
+ def __init__(self, in_channels: int, out_channels: int, kernel_size:
+ Union[int, Tuple[int, int, int]], stride: int=1, dilation: int=1,
+ pad_mode: str='constant'):
+ super().__init__()
+ if isinstance(kernel_size, int):
+ kernel_size = (kernel_size,) * 3
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
+ self.pad_mode = pad_mode
+ time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
+ height_pad = height_kernel_size // 2
+ width_pad = width_kernel_size // 2
+ self.height_pad = height_pad
+ self.width_pad = width_pad
+ self.time_pad = time_pad
+ self.time_causal_padding = (width_pad, width_pad, height_pad,
+ height_pad, time_pad, 0)
+ self.temporal_dim = 2
+ self.time_kernel_size = time_kernel_size
+ stride = stride, 1, 1
+ dilation = dilation, 1, 1
+ self.conv = CogVideoXSafeConv3d(in_channels=in_channels,
+ out_channels=out_channels, kernel_size=kernel_size, stride=
+ stride, dilation=dilation)
+ self.conv_cache = None
+
+ def fake_context_parallel_forward(self, inputs: paddle.Tensor
+ ) ->paddle.Tensor:
+ kernel_size = self.time_kernel_size
+ if kernel_size > 1:
+ cached_inputs = [self.conv_cache
+ ] if self.conv_cache is not None else [inputs[:, :, :1]] * (
+ kernel_size - 1)
+ inputs = paddle.concat(x=cached_inputs + [inputs], axis=2)
+ return inputs
+
+ def _clear_fake_context_parallel_cache(self):
+ del self.conv_cache
+ self.conv_cache = None
+
+ def forward(self, inputs: paddle.Tensor) ->paddle.Tensor:
+ inputs = self.fake_context_parallel_forward(inputs)
+ self._clear_fake_context_parallel_cache()
+ self.conv_cache = inputs[:, :, -self.time_kernel_size + 1:].clone()
+ padding_2d = (self.width_pad, self.width_pad, self.height_pad, self
+ .height_pad)
+ inputs = paddle.nn.functional.pad(x=inputs, pad=padding_2d, mode=
+ 'constant', value=0, pad_from_left_axis=False)
+ output = self.conv(inputs)
+ return output
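+
+
+# Illustrative note on the causal padding above: with kernel_size=3, stride=1 and dilation=1,
+# `time_pad` is 2, so two frames (copies of the first input frame, or the cached tail of the
+# previous chunk stored in `conv_cache`) are prepended before the temporal convolution, ensuring
+# the convolution never reads future frames; height and width are padded symmetrically by
+# kernel_size // 2 = 1.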
+
+
+class CogVideoXSpatialNorm3D(paddle.nn.Layer):
+ """
+ Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
+ to 3D-video like data.
+
+ CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
+
+ Args:
+ f_channels (`int`):
+ The number of channels for input to group normalization layer, and output of the spatial norm layer.
+ zq_channels (`int`):
+ The number of channels for the quantized vector as described in the paper.
+ groups (`int`):
+ Number of groups to separate the channels into for group normalization.
+ """
+
+ def __init__(self, f_channels: int, zq_channels: int, groups: int=32):
+ super().__init__()
+ self.norm_layer = paddle.nn.GroupNorm(num_channels=f_channels,
+ num_groups=groups, epsilon=1e-06, weight_attr=True, bias_attr=True)
+ self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels,
+ kernel_size=1, stride=1)
+ self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels,
+ kernel_size=1, stride=1)
+
+ def forward(self, f: paddle.Tensor, zq: paddle.Tensor) ->paddle.Tensor:
+ if tuple(f.shape)[2] > 1 and tuple(f.shape)[2] % 2 == 1:
+ f_first, f_rest = f[:, :, :1], f[:, :, 1:]
+ f_first_size, f_rest_size = tuple(f_first.shape)[-3:], tuple(f_rest
+ .shape)[-3:]
+ z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
+ z_first = paddle.nn.functional.interpolate(x=z_first, size=
+ f_first_size)
+ z_rest = paddle.nn.functional.interpolate(x=z_rest, size=
+ f_rest_size)
+ zq = paddle.concat(x=[z_first, z_rest], axis=2)
+ else:
+ zq = paddle.nn.functional.interpolate(x=zq, size=tuple(f.shape)
+ [-3:])
+ norm_f = self.norm_layer(f)
+ new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
+ return new_f
+
+
+class CogVideoXResnetBlock3D(paddle.nn.Layer):
+ """
+ A 3D ResNet block used in the CogVideoX model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ non_linearity (`str`, defaults to `"swish"`):
+ Activation function to use.
+ conv_shortcut (bool, defaults to `False`):
+ Whether or not to use a convolution shortcut.
+ spatial_norm_dim (`int`, *optional*):
+ The dimension to use for spatial norm if it is to be used instead of group norm.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
+ """
+
+ def __init__(self, in_channels: int, out_channels: Optional[int]=None,
+ dropout: float=0.0, temb_channels: int=512, groups: int=32, eps:
+ float=1e-06, non_linearity: str='swish', conv_shortcut: bool=False,
+ spatial_norm_dim: Optional[int]=None, pad_mode: str='first'):
+ super().__init__()
+ out_channels = out_channels or in_channels
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.nonlinearity = get_activation(non_linearity)
+ self.use_conv_shortcut = conv_shortcut
+ if spatial_norm_dim is None:
+ self.norm1 = paddle.nn.GroupNorm(num_channels=in_channels,
+ num_groups=groups, epsilon=eps)
+ self.norm2 = paddle.nn.GroupNorm(num_channels=out_channels,
+ num_groups=groups, epsilon=eps)
+ else:
+ self.norm1 = CogVideoXSpatialNorm3D(f_channels=in_channels,
+ zq_channels=spatial_norm_dim, groups=groups)
+ self.norm2 = CogVideoXSpatialNorm3D(f_channels=out_channels,
+ zq_channels=spatial_norm_dim, groups=groups)
+ self.conv1 = CogVideoXCausalConv3d(in_channels=in_channels,
+ out_channels=out_channels, kernel_size=3, pad_mode=pad_mode)
+ if temb_channels > 0:
+ self.temb_proj = paddle.nn.Linear(in_features=temb_channels,
+ out_features=out_channels)
+ self.dropout = paddle.nn.Dropout(p=dropout)
+ self.conv2 = CogVideoXCausalConv3d(in_channels=out_channels,
+ out_channels=out_channels, kernel_size=3, pad_mode=pad_mode)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = CogVideoXCausalConv3d(in_channels=
+ in_channels, out_channels=out_channels, kernel_size=3,
+ pad_mode=pad_mode)
+ else:
+ self.conv_shortcut = CogVideoXSafeConv3d(in_channels=
+ in_channels, out_channels=out_channels, kernel_size=1,
+ stride=1, padding=0)
+
+ def forward(self, inputs: paddle.Tensor, temb: Optional[paddle.Tensor]=
+ None, zq: Optional[paddle.Tensor]=None) ->paddle.Tensor:
+ hidden_states = inputs
+ if zq is not None:
+ hidden_states = self.norm1(hidden_states, zq)
+ else:
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+ if temb is not None:
+ hidden_states = hidden_states + self.temb_proj(self.
+ nonlinearity(temb))[:, :, None, None, None]
+ if zq is not None:
+ hidden_states = self.norm2(hidden_states, zq)
+ else:
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+ if self.in_channels != self.out_channels:
+ inputs = self.conv_shortcut(inputs)
+ hidden_states = hidden_states + inputs
+ return hidden_states
+
+
+class CogVideoXDownBlock3D(paddle.nn.Layer):
+ """
+ A downsampling block used in the CogVideoX model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ resnet_groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ add_downsample (`bool`, defaults to `True`):
+ Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
+ compress_time (`bool`, defaults to `False`):
+ Whether or not to downsample across temporal dimension.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
+ """
+ _supports_gradient_checkpointing = True
+
+ def __init__(self, in_channels: int, out_channels: int, temb_channels:
+ int, dropout: float=0.0, num_layers: int=1, resnet_eps: float=1e-06,
+ resnet_act_fn: str='swish', resnet_groups: int=32, add_downsample:
+ bool=True, downsample_padding: int=0, compress_time: bool=False,
+ pad_mode: str='first'):
+ super().__init__()
+ resnets = []
+ for i in range(num_layers):
+ in_channel = in_channels if i == 0 else out_channels
+ resnets.append(CogVideoXResnetBlock3D(in_channels=in_channel,
+ out_channels=out_channels, dropout=dropout, temb_channels=
+ temb_channels, groups=resnet_groups, eps=resnet_eps,
+ non_linearity=resnet_act_fn, pad_mode=pad_mode))
+ self.resnets = paddle.nn.LayerList(sublayers=resnets)
+ self.downsamplers = None
+ if add_downsample:
+ self.downsamplers = paddle.nn.LayerList(sublayers=[
+ CogVideoXDownsample3D(out_channels, out_channels, padding=
+ downsample_padding, compress_time=compress_time)])
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: paddle.Tensor, temb: Optional[paddle.
+ Tensor]=None, zq: Optional[paddle.Tensor]=None) ->paddle.Tensor:
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+
+ def create_forward(*inputs):
+ return module(*inputs)
+ return create_forward
+ hidden_states = paddle.distributed.fleet.utils.recompute(
+ create_custom_forward(resnet), hidden_states, temb, zq)
+ else:
+ hidden_states = resnet(hidden_states, temb, zq)
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+ return hidden_states
+
+
+class CogVideoXMidBlock3D(paddle.nn.Layer):
+ """
+ A middle block used in the CogVideoX model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ resnet_groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ spatial_norm_dim (`int`, *optional*):
+ The dimension to use for spatial norm if it is to be used instead of group norm.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
+ """
+ _supports_gradient_checkpointing = True
+
+ def __init__(self, in_channels: int, temb_channels: int, dropout: float
+ =0.0, num_layers: int=1, resnet_eps: float=1e-06, resnet_act_fn:
+ str='swish', resnet_groups: int=32, spatial_norm_dim: Optional[int]
+ =None, pad_mode: str='first'):
+ super().__init__()
+ resnets = []
+ for _ in range(num_layers):
+ resnets.append(CogVideoXResnetBlock3D(in_channels=in_channels,
+ out_channels=in_channels, dropout=dropout, temb_channels=
+ temb_channels, groups=resnet_groups, eps=resnet_eps,
+ spatial_norm_dim=spatial_norm_dim, non_linearity=
+ resnet_act_fn, pad_mode=pad_mode))
+ self.resnets = paddle.nn.LayerList(sublayers=resnets)
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: paddle.Tensor, temb: Optional[paddle.
+ Tensor]=None, zq: Optional[paddle.Tensor]=None) ->paddle.Tensor:
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+
+ def create_forward(*inputs):
+ return module(*inputs)
+ return create_forward
+ hidden_states = paddle.distributed.fleet.utils.recompute(
+ create_custom_forward(resnet), hidden_states, temb, zq)
+ else:
+ hidden_states = resnet(hidden_states, temb, zq)
+ return hidden_states
+
+
+class CogVideoXUpBlock3D(paddle.nn.Layer):
+ """
+ An upsampling block used in the CogVideoX model.
+
+ Args:
+ in_channels (`int`):
+ Number of input channels.
+ out_channels (`int`, *optional*):
+ Number of output channels. If None, defaults to `in_channels`.
+ temb_channels (`int`, defaults to `512`):
+ Number of time embedding channels.
+ dropout (`float`, defaults to `0.0`):
+ Dropout rate.
+ num_layers (`int`, defaults to `1`):
+ Number of resnet layers.
+ resnet_eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ resnet_act_fn (`str`, defaults to `"swish"`):
+ Activation function to use.
+ resnet_groups (`int`, defaults to `32`):
+ Number of groups to separate the channels into for group normalization.
+ spatial_norm_dim (`int`, defaults to `16`):
+ The dimension to use for spatial norm if it is to be used instead of group norm.
+ add_upsample (`bool`, defaults to `True`):
+ Whether or not to use an upsampling layer. If not used, the output dimension is the same as the input dimension.
+ compress_time (`bool`, defaults to `False`):
+ Whether or not to downsample across temporal dimension.
+ pad_mode (str, defaults to `"first"`):
+ Padding mode.
+ """
+
+ def __init__(self, in_channels: int, out_channels: int, temb_channels:
+ int, dropout: float=0.0, num_layers: int=1, resnet_eps: float=1e-06,
+ resnet_act_fn: str='swish', resnet_groups: int=32, spatial_norm_dim:
+ int=16, add_upsample: bool=True, upsample_padding: int=1,
+ compress_time: bool=False, pad_mode: str='first'):
+ super().__init__()
+ resnets = []
+ for i in range(num_layers):
+ in_channel = in_channels if i == 0 else out_channels
+ resnets.append(CogVideoXResnetBlock3D(in_channels=in_channel,
+ out_channels=out_channels, dropout=dropout, temb_channels=
+ temb_channels, groups=resnet_groups, eps=resnet_eps,
+ non_linearity=resnet_act_fn, spatial_norm_dim=
+ spatial_norm_dim, pad_mode=pad_mode))
+ self.resnets = paddle.nn.LayerList(sublayers=resnets)
+ self.upsamplers = None
+ if add_upsample:
+ self.upsamplers = paddle.nn.LayerList(sublayers=[
+ CogVideoXUpsample3D(out_channels, out_channels, padding=
+ upsample_padding, compress_time=compress_time)])
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: paddle.Tensor, temb: Optional[paddle.
+ Tensor]=None, zq: Optional[paddle.Tensor]=None) ->paddle.Tensor:
+ """Forward method of the `CogVideoXUpBlock3D` class."""
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+
+ def create_forward(*inputs):
+ return module(*inputs)
+ return create_forward
+ hidden_states = paddle.distributed.fleet.utils.recompute(
+ create_custom_forward(resnet), hidden_states, temb, zq)
+ else:
+ hidden_states = resnet(hidden_states, temb, zq)
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+ return hidden_states
+
+
+class CogVideoXEncoder3D(paddle.nn.Layer):
+ """
+ The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
+
+ Args:
+ in_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ out_channels (`int`, *optional*, defaults to 3):
+ The number of output channels.
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
+ options.
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
+ The number of output channels for each block.
+ act_fn (`str`, *optional*, defaults to `"silu"`):
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+ layers_per_block (`int`, *optional*, defaults to 2):
+ The number of layers per block.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups for normalization.
+ """
+ _supports_gradient_checkpointing = True
+
+ def __init__(self, in_channels: int=3, out_channels: int=16,
+ down_block_types: Tuple[str, ...]=('CogVideoXDownBlock3D',
+ 'CogVideoXDownBlock3D', 'CogVideoXDownBlock3D',
+ 'CogVideoXDownBlock3D'), block_out_channels: Tuple[int, ...]=(128,
+ 256, 256, 512), layers_per_block: int=3, act_fn: str='silu',
+ norm_eps: float=1e-06, norm_num_groups: int=32, dropout: float=0.0,
+ pad_mode: str='first', temporal_compression_ratio: float=4):
+ super().__init__()
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
+ self.conv_in = CogVideoXCausalConv3d(in_channels,
+ block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
+ self.down_blocks = paddle.nn.LayerList(sublayers=[])
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ compress_time = i < temporal_compress_level
+ if down_block_type == 'CogVideoXDownBlock3D':
+ down_block = CogVideoXDownBlock3D(in_channels=input_channel,
+ out_channels=output_channel, temb_channels=0, dropout=
+ dropout, num_layers=layers_per_block, resnet_eps=
+ norm_eps, resnet_act_fn=act_fn, resnet_groups=
+ norm_num_groups, add_downsample=not is_final_block,
+ compress_time=compress_time)
+ else:
+ raise ValueError(
+ 'Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`'
+ )
+ self.down_blocks.append(down_block)
+ self.mid_block = CogVideoXMidBlock3D(in_channels=block_out_channels
+ [-1], temb_channels=0, dropout=dropout, num_layers=2,
+ resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=
+ norm_num_groups, pad_mode=pad_mode)
+ self.norm_out = paddle.nn.GroupNorm(num_groups=norm_num_groups,
+ num_channels=block_out_channels[-1], epsilon=1e-06)
+ self.conv_act = paddle.nn.Silu()
+ self.conv_out = CogVideoXCausalConv3d(block_out_channels[-1], 2 *
+ out_channels, kernel_size=3, pad_mode=pad_mode)
+ self.gradient_checkpointing = False
+
+ def forward(self, sample: paddle.Tensor, temb: Optional[paddle.Tensor]=None
+ ) ->paddle.Tensor:
+ """The forward method of the `CogVideoXEncoder3D` class."""
+ hidden_states = self.conv_in(sample)
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+ for down_block in self.down_blocks:
+ hidden_states = paddle.distributed.fleet.utils.recompute(
+ create_custom_forward(down_block), hidden_states, temb,
+ None)
+ hidden_states = paddle.distributed.fleet.utils.recompute(
+ create_custom_forward(self.mid_block), hidden_states, temb,
+ None)
+ else:
+ for down_block in self.down_blocks:
+ hidden_states = down_block(hidden_states, temb, None)
+ hidden_states = self.mid_block(hidden_states, temb, None)
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ return hidden_states
+
+
+class CogVideoXDecoder3D(paddle.nn.Layer):
+ """
+ The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
+ sample.
+
+ Args:
+ in_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ out_channels (`int`, *optional*, defaults to 3):
+ The number of output channels.
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
+ The number of output channels for each block.
+ act_fn (`str`, *optional*, defaults to `"silu"`):
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+ layers_per_block (`int`, *optional*, defaults to 2):
+ The number of layers per block.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups for normalization.
+ """
+ _supports_gradient_checkpointing = True
+
+ def __init__(self, in_channels: int=16, out_channels: int=3,
+ up_block_types: Tuple[str, ...]=('CogVideoXUpBlock3D',
+ 'CogVideoXUpBlock3D', 'CogVideoXUpBlock3D', 'CogVideoXUpBlock3D'),
+ block_out_channels: Tuple[int, ...]=(128, 256, 256, 512),
+ layers_per_block: int=3, act_fn: str='silu', norm_eps: float=1e-06,
+ norm_num_groups: int=32, dropout: float=0.0, pad_mode: str='first',
+ temporal_compression_ratio: float=4):
+ super().__init__()
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ self.conv_in = CogVideoXCausalConv3d(in_channels,
+ reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
+ self.mid_block = CogVideoXMidBlock3D(in_channels=
+ reversed_block_out_channels[0], temb_channels=0, num_layers=2,
+ resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=
+ norm_num_groups, spatial_norm_dim=in_channels, pad_mode=pad_mode)
+ self.up_blocks = paddle.nn.LayerList(sublayers=[])
+ output_channel = reversed_block_out_channels[0]
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+ compress_time = i < temporal_compress_level
+ if up_block_type == 'CogVideoXUpBlock3D':
+ up_block = CogVideoXUpBlock3D(in_channels=
+ prev_output_channel, out_channels=output_channel,
+ temb_channels=0, dropout=dropout, num_layers=
+ layers_per_block + 1, resnet_eps=norm_eps,
+ resnet_act_fn=act_fn, resnet_groups=norm_num_groups,
+ spatial_norm_dim=in_channels, add_upsample=not
+ is_final_block, compress_time=compress_time, pad_mode=
+ pad_mode)
+ prev_output_channel = output_channel
+ else:
+ raise ValueError(
+ 'Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`'
+ )
+ self.up_blocks.append(up_block)
+ self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[
+ -1], in_channels, groups=norm_num_groups)
+ self.conv_act = paddle.nn.Silu()
+ self.conv_out = CogVideoXCausalConv3d(reversed_block_out_channels[-
+ 1], out_channels, kernel_size=3, pad_mode=pad_mode)
+ self.gradient_checkpointing = False
+
+ def forward(self, sample: paddle.Tensor, temb: Optional[paddle.Tensor]=None
+ ) ->paddle.Tensor:
+ """The forward method of the `CogVideoXDecoder3D` class."""
+ hidden_states = self.conv_in(sample)
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+
+ def custom_forward(*inputs):
+ return module(*inputs)
+ return custom_forward
+ hidden_states = paddle.distributed.fleet.utils.recompute(
+ create_custom_forward(self.mid_block), hidden_states, temb,
+ sample)
+ for up_block in self.up_blocks:
+ hidden_states = paddle.distributed.fleet.utils.recompute(
+ create_custom_forward(up_block), hidden_states, temb,
+ sample)
+ else:
+ hidden_states = self.mid_block(hidden_states, temb, sample)
+ for up_block in self.up_blocks:
+ hidden_states = up_block(hidden_states, temb, sample)
+ hidden_states = self.norm_out(hidden_states, sample)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ return hidden_states
+
+
+class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin):
+ """
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
+ [CogVideoX](https://github.com/THUDM/CogVideo).
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ force_upcast (`bool`, *optional*, defaults to `True`):
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. The VAE
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
+ """
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ['CogVideoXResnetBlock3D']
+
+ @register_to_config
+ def __init__(self, in_channels: int=3, out_channels: int=3,
+ down_block_types: Tuple[str]=('CogVideoXDownBlock3D', 'CogVideoXDownBlock3D', 'CogVideoXDownBlock3D', 'CogVideoXDownBlock3D'),
+ up_block_types: Tuple[str]=('CogVideoXUpBlock3D', 'CogVideoXUpBlock3D', 'CogVideoXUpBlock3D', 'CogVideoXUpBlock3D'),
+ block_out_channels: Tuple[int]=(128, 256, 256, 512), latent_channels: int=16,
+ layers_per_block: int=3, act_fn: str='silu', norm_eps: float=1e-06, norm_num_groups: int=32,
+ temporal_compression_ratio: float=4, sample_height: int=480, sample_width: int=720,
+ scaling_factor: float=1.15258426, shift_factor: Optional[float]=None,
+ latents_mean: Optional[Tuple[float]]=None, latents_std: Optional[Tuple[float]]=None,
+ force_upcast: float=True, use_quant_conv: bool=False, use_post_quant_conv: bool=False):
+ super().__init__()
+ self.encoder = CogVideoXEncoder3D(
+ in_channels=in_channels, out_channels=latent_channels, down_block_types=down_block_types,
+ block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn,
+ norm_eps=norm_eps, norm_num_groups=norm_num_groups, temporal_compression_ratio=temporal_compression_ratio)
+ self.decoder = CogVideoXDecoder3D(
+ in_channels=latent_channels, out_channels=out_channels, up_block_types=up_block_types,
+ block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn,
+ norm_eps=norm_eps, norm_num_groups=norm_num_groups, temporal_compression_ratio=temporal_compression_ratio)
+ self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
+ self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
+ self.use_slicing = False
+ self.use_tiling = False
+ self.num_latent_frames_batch_size = 2
+ self.num_sample_frames_batch_size = 8
+ self.tile_sample_min_height = sample_height // 2
+ self.tile_sample_min_width = sample_width // 2
+ self.tile_latent_min_height = int(self.tile_sample_min_height / 2 ** (len(self.config.block_out_channels) - 1))
+ self.tile_latent_min_width = int(self.tile_sample_min_width / 2 ** (len(self.config.block_out_channels) - 1))
+ self.tile_overlap_factor_height = 1 / 6
+ self.tile_overlap_factor_width = 1 / 5
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
+ module.gradient_checkpointing = value
+
+ def _clear_fake_context_parallel_cache(self):
+ for name, module in self.named_sublayers():
+ if isinstance(module, CogVideoXCausalConv3d):
+ logger.debug(
+ f'Clearing fake Context Parallel cache for layer: {name}')
+ module._clear_fake_context_parallel_cache()
+
+ def enable_tiling(self, tile_sample_min_height: Optional[int]=None, tile_sample_min_width: Optional[int]=None,
+ tile_overlap_factor_height: Optional[float]=None, tile_overlap_factor_width: Optional[float]=None) -> None:
+ """
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_overlap_factor_height (`float`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
+ value might cause more tiles to be processed, leading to a slowdown of the decoding process.
+ tile_overlap_factor_width (`float`, *optional*):
+ The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
+ are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
+ value might cause more tiles to be processed, leading to a slowdown of the decoding process.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_latent_min_height = int(self.tile_sample_min_height / 2 ** (len(self.config.block_out_channels) - 1))
+ self.tile_latent_min_width = int(self.tile_sample_min_width / 2 ** (len(self.config.block_out_channels) - 1))
+ self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
+ self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
+
+ def disable_tiling(self) ->None:
+ """
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) ->None:
+ """
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) ->None:
+ """
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def _encode(self, x: paddle.Tensor) ->paddle.Tensor:
+ batch_size, num_channels, num_frames, height, width = tuple(x.shape)
+ if self.use_tiling and (width > self.tile_sample_min_width or
+ height > self.tile_sample_min_height):
+ return self.tiled_encode(x)
+ frame_batch_size = self.num_sample_frames_batch_size
+ num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+ enc = []
+ for i in range(num_batches):
+ remaining_frames = num_frames % frame_batch_size
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
+ x_intermediate = x[:, :, start_frame:end_frame]
+ x_intermediate = self.encoder(x_intermediate)
+ if self.quant_conv is not None:
+ x_intermediate = self.quant_conv(x_intermediate)
+ enc.append(x_intermediate)
+ self._clear_fake_context_parallel_cache()
+ enc = paddle.concat(x=enc, axis=2)
+ return enc
+
+ @apply_forward_hook
+ def encode(self, x: paddle.Tensor, return_dict: bool=True) ->Union[
+ AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`paddle.Tensor`): Input batch of videos.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and tuple(x.shape)[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = paddle.concat(x=encoded_slices)
+ else:
+ h = self._encode(x)
+ posterior = DiagonalGaussianDistribution(h)
+ if not return_dict:
+ return posterior,
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: paddle.Tensor, return_dict: bool=True) ->Union[
+ DecoderOutput, paddle.Tensor]:
+ batch_size, num_channels, num_frames, height, width = tuple(z.shape)
+ if self.use_tiling and (width > self.tile_latent_min_width or
+ height > self.tile_latent_min_height):
+ return self.tiled_decode(z, return_dict=return_dict)
+ frame_batch_size = self.num_latent_frames_batch_size
+ num_batches = num_frames // frame_batch_size
+ dec = []
+ for i in range(num_batches):
+ remaining_frames = num_frames % frame_batch_size
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
+ z_intermediate = z[:, :, start_frame:end_frame]
+ if self.post_quant_conv is not None:
+ z_intermediate = self.post_quant_conv(z_intermediate)
+ z_intermediate = self.decoder(z_intermediate)
+ dec.append(z_intermediate)
+ self._clear_fake_context_parallel_cache()
+ dec = paddle.concat(x=dec, axis=2)
+ if not return_dict:
+ return dec,
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(self, z: paddle.Tensor, return_dict: bool=True) ->Union[
+ DecoderOutput, paddle.Tensor]:
+ """
+ Decode a batch of images.
+
+ Args:
+ z (`paddle.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and tuple(z.shape)[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z
+ .split(1)]
+ decoded = paddle.concat(x=decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+ if not return_dict:
+ return decoded,
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: paddle.Tensor, b: paddle.Tensor, blend_extent: int) -> paddle.Tensor:
+ blend_extent = min(tuple(a.shape)[3], tuple(b.shape)[3], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
+ return b
+
+ def blend_h(self, a: paddle.Tensor, b: paddle.Tensor, blend_extent: int) -> paddle.Tensor:
+ blend_extent = min(tuple(a.shape)[4], tuple(b.shape)[4], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
+ return b
+
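+ # blend_v / blend_h linearly cross-fade the overlapping border of two neighbouring tiles:
+ # for offset y in [0, blend_extent), the result is
+ # prev_tile_edge * (1 - y / blend_extent) + current_tile * (y / blend_extent).
+ # tiled_encode / tiled_decode below rely on this blending to hide seams between tiles.
+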
+ def tiled_encode(self, x: paddle.Tensor) ->paddle.Tensor:
+ """Encode a batch of images using a tiled encoder.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+ output, but they should be much less noticeable.
+
+ Args:
+ x (`paddle.Tensor`): Input batch of videos.
+
+ Returns:
+ `paddle.Tensor`:
+ The latent representation of the encoded videos.
+ """
+ batch_size, num_channels, num_frames, height, width = tuple(x.shape)
+ overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
+ overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
+ blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
+ blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
+ row_limit_height = self.tile_latent_min_height - blend_extent_height
+ row_limit_width = self.tile_latent_min_width - blend_extent_width
+ frame_batch_size = self.num_sample_frames_batch_size
+ rows = []
+ for i in range(0, height, overlap_height):
+ row = []
+ for j in range(0, width, overlap_width):
+ num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
+ time = []
+ for k in range(num_batches):
+ remaining_frames = num_frames % frame_batch_size
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
+ tile = x[:, :, start_frame:end_frame, i:i + self.tile_sample_min_height, j:j + self.tile_sample_min_width]
+ tile = self.encoder(tile)
+ if self.quant_conv is not None:
+ tile = self.quant_conv(tile)
+ time.append(tile)
+ self._clear_fake_context_parallel_cache()
+ row.append(paddle.concat(x=time, axis=2))
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+ result_rows.append(paddle.concat(x=result_row, axis=4))
+ enc = paddle.concat(x=result_rows, axis=3)
+ return enc
+
+ def tiled_decode(self, z: paddle.Tensor, return_dict: bool=True) ->Union[
+ DecoderOutput, paddle.Tensor]:
+ """
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`paddle.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ batch_size, num_channels, num_frames, height, width = tuple(z.shape)
+ overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
+ overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
+ blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
+ blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
+ row_limit_height = self.tile_sample_min_height - blend_extent_height
+ row_limit_width = self.tile_sample_min_width - blend_extent_width
+ frame_batch_size = self.num_latent_frames_batch_size
+ rows = []
+ for i in range(0, height, overlap_height):
+ row = []
+ for j in range(0, width, overlap_width):
+ num_batches = num_frames // frame_batch_size
+ time = []
+ for k in range(num_batches):
+ remaining_frames = num_frames % frame_batch_size
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
+ tile = z[:, :, start_frame:end_frame, i:i + self.tile_latent_min_height, j:j + self.tile_latent_min_width]
+ if self.post_quant_conv is not None:
+ tile = self.post_quant_conv(tile)
+ tile = self.decoder(tile)
+ time.append(tile)
+ self._clear_fake_context_parallel_cache()
+ row.append(paddle.concat(x=time, axis=2))
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+ result_rows.append(paddle.concat(x=result_row, axis=4))
+ dec = paddle.concat(x=result_rows, axis=3)
+ if not return_dict:
+ return dec,
+ return DecoderOutput(sample=dec)
+
+ def forward(self, sample: paddle.Tensor, sample_posterior: bool=False,
+ return_dict: bool=True, generator: Optional[paddle.Generator]=None
+ ) -> Union[DecoderOutput, paddle.Tensor]:
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ if not return_dict:
+ return dec,
+ return dec
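+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only): a randomly-initialized VAE round-trips a
+    # small random clip, so the output is noise. The tensor sizes are hypothetical; real
+    # weights would normally be loaded with `AutoencoderKLCogVideoX.from_pretrained(...)`.
+    vae = AutoencoderKLCogVideoX()
+    vae.enable_tiling()  # encode/decode in overlapping spatial tiles to reduce peak memory
+    video = paddle.randn([1, 3, 9, 256, 256])  # [batch, channels, frames, height, width]
+    latents = vae.encode(video).latent_dist.sample()
+    recon = vae.decode(latents).sample  # reconstructed clip with the same layout as `video`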
diff --git a/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d.py b/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d.py
new file mode 100644
index 000000000..b6db5c965
--- /dev/null
+++ b/ppdiffusers/ppdiffusers/models/cogvideox_transformer_3d.py
@@ -0,0 +1,394 @@
+import paddle
+from typing import Any, Dict, Optional, Tuple, Union
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import logging
+from ..utils.paddle_utils import maybe_allow_in_graph
+from .attention import Attention, FeedForward
+from .attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
+from .embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
+from .modeling_outputs import Transformer2DModelOutput
+from .modeling_utils import ModelMixin
+from .normalization import AdaLayerNorm, CogVideoXLayerNormZero
+logger = logging.get_logger(__name__)
+
+
+@maybe_allow_in_graph
+class CogVideoXBlock(paddle.nn.Layer):
+ """
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
+
+ Parameters:
+ dim (`int`):
+ The number of channels in the input and output.
+ num_attention_heads (`int`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`):
+ The number of channels in each head.
+ time_embed_dim (`int`):
+ The number of channels in timestep embedding.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to be used in feed-forward.
+ attention_bias (`bool`, defaults to `False`):
+ Whether or not to use bias in attention projection layers.
+ qk_norm (`bool`, defaults to `True`):
+ Whether or not to use normalization after query and key projections in Attention.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_eps (`float`, defaults to `1e-5`):
+ Epsilon value for normalization layers.
+ final_dropout (`bool` defaults to `False`):
+ Whether to apply a final dropout after the last feed-forward layer.
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
+ ff_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Feed-forward layer.
+ attention_out_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in Attention output projection layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ time_embed_dim: int,
+ dropout: float=0.0,
+ activation_fn: str='gelu-approximate',
+ attention_bias: bool=False,
+ qk_norm: bool=True,
+ norm_elementwise_affine: bool=True,
+ norm_eps: float=1e-05,
+ final_dropout: bool=True,
+ ff_inner_dim: Optional[int]=None,
+ ff_bias: bool=True,
+ attention_out_bias: bool=True
+ ):
+ super().__init__()
+
+ # 1. self attention
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ qk_norm='layer_norm' if qk_norm else None,
+ eps=1e-06,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0()
+ )
+
+ # 2. feed forward
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias
+ )
+
+ def forward(
+ self,
+ hidden_states: paddle.Tensor,
+ encoder_hidden_states: paddle.Tensor,
+ temb: paddle.Tensor,
+ image_rotary_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]]=None
+ ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+ text_seq_length = encoder_hidden_states.shape[1]
+
+ # norm and modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # attention
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb
+ )
+
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
+
+ # norm and modulate
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
+ hidden_states, encoder_hidden_states, temb
+ )
+
+ # feed forward
+ norm_hidden_states = paddle.concat([norm_encoder_hidden_states, norm_hidden_states], axis=1)
+ ff_output = self.ff(norm_hidden_states)
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = (encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length])
+
+ return hidden_states, encoder_hidden_states
+
+
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
+ """
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
+
+ Parameters:
+ num_attention_heads (`int`, defaults to `30`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `64`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ time_embed_dim (`int`, defaults to `512`):
+ Output dimension of timestep embeddings.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ num_layers (`int`, defaults to `30`):
+ The number of layers of Transformer blocks to use.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability to use.
+ attention_bias (`bool`, defaults to `True`):
+ Whether or not to use bias in the attention projection layers.
+ sample_width (`int`, defaults to `90`):
+ The width of the input latents.
+ sample_height (`int`, defaults to `60`):
+ The height of the input latents.
+ sample_frames (`int`, defaults to `49`):
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
+ instead of 13 because CogVideoX processes 13 latent frames at once in its default and recommended settings,
+ but it cannot be changed to the correct value without breaking backwards compatibility. To create a
+ transformer with K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
+ patch_size (`int`, defaults to `2`):
+ The size of the patches to use in the patch embedding layer.
+ temporal_compression_ratio (`int`, defaults to `4`):
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
+ max_text_seq_length (`int`, defaults to `226`):
+ The maximum sequence length of the input text embeddings.
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
+ Activation function to use in feed-forward.
+ timestep_activation_fn (`str`, defaults to `"silu"`):
+ Activation function to use when generating the timestep embeddings.
+ norm_elementwise_affine (`bool`, defaults to `True`):
+ Whether or not to use elementwise affine in normalization layers.
+ norm_eps (`float`, defaults to `1e-5`):
+ The epsilon value to use in normalization layers.
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
+ """
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(self, num_attention_heads: int=30, attention_head_dim: int=64, in_channels: int=16,
+ out_channels: Optional[int]=16, flip_sin_to_cos: bool=True, freq_shift: int=0,
+ time_embed_dim: int=512, text_embed_dim: int=4096, num_layers: int=30, dropout: float=0.0,
+ attention_bias: bool=True, sample_width: int=90, sample_height: int=60, sample_frames: int=49,
+ patch_size: int=2, temporal_compression_ratio: int=4, max_text_seq_length: int=226,
+ activation_fn: str='gelu-approximate', timestep_activation_fn: str='silu',
+ norm_elementwise_affine: bool=True, norm_eps: float=1e-05,
+ spatial_interpolation_scale: float=1.875, temporal_interpolation_scale: float=1.0,
+ use_rotary_positional_embeddings: bool=False,
+ use_learned_positional_embeddings: bool=False):
+ super().__init__()
+ inner_dim = num_attention_heads * attention_head_dim
+ if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+ raise ValueError(
+ "There are no CogVideoX checkpoints available with rotary embeddings disabled and learned positional embeddings enabled. If you're using a custom model and/or believe this should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
+ )
+ self.patch_embed = CogVideoXPatchEmbed(
+ patch_size=patch_size, in_channels=in_channels, embed_dim=inner_dim, text_embed_dim=text_embed_dim,
+ bias=True, sample_width=sample_width, sample_height=sample_height, sample_frames=sample_frames,
+ temporal_compression_ratio=temporal_compression_ratio, max_text_seq_length=max_text_seq_length,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ use_positional_embeddings=not use_rotary_positional_embeddings,
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
+ )
+ self.embedding_dropout = paddle.nn.Dropout(p=dropout)
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
+ self.transformer_blocks = paddle.nn.LayerList(sublayers=[
+ CogVideoXBlock(
+ dim=inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim,
+ time_embed_dim=time_embed_dim, dropout=dropout, activation_fn=activation_fn,
+ attention_bias=attention_bias, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps)
+ for _ in range(num_layers)])
+ self.norm_final = paddle.nn.LayerNorm(normalized_shape=inner_dim, epsilon=norm_eps,
+ weight_attr=norm_elementwise_affine, bias_attr=norm_elementwise_affine)
+ self.norm_out = AdaLayerNorm(embedding_dim=time_embed_dim, output_dim=2 * inner_dim,
+ norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, chunk_dim=1)
+ self.proj_out = paddle.nn.Linear(in_features=inner_dim, out_features=patch_size * patch_size * out_channels)
+ self.gradient_checkpointing = False
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ self.gradient_checkpointing = value
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ """
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: paddle.nn.Layer, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, 'get_processor'):
+ processors[f'{name}.processor'] = module.get_processor()
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f'{name}.{sub_name}', child, processors)
+ return processors
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ """
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f'A dict of processors was passed, but the number of processors {len(processor)} does not match the number of attention layers: {count}. Please make sure to pass {count} processor classes.'
+ )
+
+ def fn_recursive_attn_processor(name: str, module: paddle.nn.Layer, processor):
+ if hasattr(module, 'set_processor'):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f'{name}.processor'))
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f'{name}.{sub_name}', child, processor)
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ This API is 🧪 experimental.
+ """
+ self.original_attn_processors = None
+ for _, attn_processor in self.attn_processors.items():
+ if 'Added' in str(attn_processor.__class__.__name__):
+ raise ValueError(
+ '`fuse_qkv_projections()` is not supported for models having added KV projections.'
+ )
+ self.original_attn_processors = self.attn_processors
+ for module in self.sublayers():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
+
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+ This API is 🧪 experimental.
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def forward(
+ self,
+ hidden_states: paddle.Tensor,
+ encoder_hidden_states: paddle.Tensor,
+ timestep: Union[int, float, paddle.Tensor],
+ timestep_cond: Optional[paddle.Tensor]=None,
+ image_rotary_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]]=None,
+ return_dict: bool=True
+ ):
+ batch_size, num_frames, channels, height, width = hidden_states.shape
+
+ # 1. Time embedding
+ timesteps = timestep
+ t_emb = self.time_proj(timesteps)
+ t_emb = t_emb.cast(hidden_states.dtype)
+ emb = self.time_embedding(t_emb, timestep_cond)
+
+ # 2. Patch embedding
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+ hidden_states = self.embedding_dropout(hidden_states)
+
+ text_seq_length = tuple(encoder_hidden_states.shape)[1]
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 3. Transformer blocks
+ for i, block in enumerate(self.transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+ raise NotImplementedError
+ else:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb
+ )
+
+ if not self.config.use_rotary_positional_embeddings:
+ # CogVideoX-2B variant (no rotary positional embeddings)
+ hidden_states = self.norm_final(hidden_states)
+ else:
+ # CogVideoX-5B variant (rotary positional embeddings)
+ hidden_states = paddle.concat(x=[encoder_hidden_states, hidden_states], axis=1)
+ hidden_states = self.norm_final(hidden_states)
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # 4. Final block
+ hidden_states = self.norm_out(hidden_states, temb=emb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. Unpatchify
+ p = self.config.patch_size
+ output = hidden_states.reshape([batch_size, num_frames, height // p, width // p, -1, p, p])
+ output = output.transpose(perm=[0, 1, 4, 2, 5, 3, 6]).flatten(5, 6).flatten(3, 4)
+ if not return_dict:
+ return output,
+ return Transformer2DModelOutput(sample=output)
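+
+
+if __name__ == "__main__":
+    # Shape sketch (illustrative only): a tiny, randomly-initialized transformer with
+    # made-up dimensions, just to show the expected input/output layout. Real
+    # configurations come from the CogVideoX-2b / CogVideoX-5b checkpoints.
+    model = CogVideoXTransformer3DModel(
+        num_attention_heads=2, attention_head_dim=32, in_channels=4, out_channels=4,
+        num_layers=1, sample_width=16, sample_height=16, sample_frames=5,
+        text_embed_dim=32, time_embed_dim=64, max_text_seq_length=8)
+    latents = paddle.randn([1, 2, 4, 16, 16])   # [batch, frames, channels, height, width]
+    text_embeds = paddle.randn([1, 8, 32])      # [batch, seq_len, text_embed_dim]
+    out = model(latents, text_embeds, timestep=paddle.to_tensor([999]))
+    print(out.sample.shape)  # same [batch, frames, out_channels, height, width] layout as `latents`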
diff --git a/ppdiffusers/ppdiffusers/models/downsampling.py b/ppdiffusers/ppdiffusers/models/downsampling.py
new file mode 100644
index 000000000..b99623ab1
--- /dev/null
+++ b/ppdiffusers/ppdiffusers/models/downsampling.py
@@ -0,0 +1,326 @@
+import paddle
+from typing import Optional, Tuple
+from .normalization import RMSNorm
+from .upsampling import upfirdn2d_native
+
+
+class Downsample1D(paddle.nn.Layer):
+ """A 1D downsampling layer with an optional convolution.
+
+ Parameters:
+ channels (`int`):
+ number of channels in the inputs and outputs.
+ use_conv (`bool`, default `False`):
+ option to use a convolution.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ padding (`int`, default `1`):
+ padding for the convolution.
+ name (`str`, default `conv`):
+ name of the downsampling 1D layer.
+ """
+
+ def __init__(self, channels: int, use_conv: bool=False, out_channels:
+ Optional[int]=None, padding: int=1, name: str='conv'):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+ if use_conv:
+ self.conv = paddle.nn.Conv1D(in_channels=self.channels,
+ out_channels=self.out_channels, kernel_size=3, stride=
+ stride, padding=padding)
+ else:
+ assert self.channels == self.out_channels
+ self.conv = paddle.nn.AvgPool1D(kernel_size=stride, stride=
+ stride, exclusive=False)
+
+ def forward(self, inputs: paddle.Tensor) ->paddle.Tensor:
+ assert tuple(inputs.shape)[1] == self.channels
+ return self.conv(inputs)
+
+
+class Downsample2D(paddle.nn.Layer):
+ """A 2D downsampling layer with an optional convolution.
+
+ Parameters:
+ channels (`int`):
+ number of channels in the inputs and outputs.
+ use_conv (`bool`, default `False`):
+ option to use a convolution.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ padding (`int`, default `1`):
+ padding for the convolution.
+ name (`str`, default `conv`):
+ name of the downsampling 2D layer.
+ """
+
+ def __init__(self, channels: int, use_conv: bool=False, out_channels: Optional[int]=None,
+ padding: int=1, name: str='conv', kernel_size=3, norm_type=None, eps=None,
+ elementwise_affine=None, bias=True):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+ if norm_type == 'ln_norm':
+ self.norm = paddle.nn.LayerNorm(normalized_shape=channels, epsilon=eps,
+ weight_attr=elementwise_affine, bias_attr=elementwise_affine)
+ elif norm_type == 'rms_norm':
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
+ elif norm_type is None:
+ self.norm = None
+ else:
+ raise ValueError(f'unknown norm_type: {norm_type}')
+ if use_conv:
+ conv = paddle.nn.Conv2D(in_channels=self.channels, out_channels=self.out_channels,
+ kernel_size=kernel_size, stride=stride, padding=padding, bias_attr=bias)
+ else:
+ assert self.channels == self.out_channels
+ conv = paddle.nn.AvgPool2D(kernel_size=stride, stride=stride, exclusive=False)
+ if name == 'conv':
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == 'Conv2d_0':
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, hidden_states: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
+ if len(args) > 0 or kwargs.get('scale', None) is not None:
+ deprecation_message = (
+ 'The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`.'
+ )
+ print('scale', '1.0.0', deprecation_message)
+ assert tuple(hidden_states.shape)[1] == self.channels
+ if self.norm is not None:
+ hidden_states = self.norm(hidden_states.transpose(perm=[0, 2, 3, 1])).transpose(perm=[0, 3, 1, 2])
+ if self.use_conv and self.padding == 0:
+ pad = 0, 1, 0, 1
+ hidden_states = paddle.nn.functional.pad(x=hidden_states, pad=pad, mode='constant', value=0, pad_from_left_axis=False)
+ assert tuple(hidden_states.shape)[1] == self.channels
+ hidden_states = self.conv(hidden_states)
+ return hidden_states
+
+
+class FirDownsample2D(paddle.nn.Layer):
+ """A 2D FIR downsampling layer with an optional convolution.
+
+ Parameters:
+ channels (`int`):
+ number of channels in the inputs and outputs.
+ use_conv (`bool`, default `False`):
+ option to use a convolution.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
+ kernel for the FIR filter.
+ """
+
+ def __init__(self, channels: Optional[int]=None, out_channels: Optional[int]=None,
+ use_conv: bool=False, fir_kernel: Tuple[int, int, int, int]=(1, 3, 3, 1)):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = paddle.nn.Conv2D(in_channels=channels,
+ out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+ self.fir_kernel = fir_kernel
+ self.use_conv = use_conv
+ self.out_channels = out_channels
+
+ def _downsample_2d(self, hidden_states: paddle.Tensor, weight: Optional[paddle.Tensor]=None,
+ kernel: Optional[paddle.Tensor]=None, factor: int=2, gain: float=1) -> paddle.Tensor:
+ """Fused `Conv2d()` followed by `downsample_2d()`.
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+ efficient than performing the same calculation using standard PaddlePaddle ops. It supports gradients of
+ arbitrary order.
+
+ Args:
+ hidden_states (`paddle.Tensor`):
+ Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ weight (`paddle.Tensor`, *optional*):
+ Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
+ performed by `inChannels = x.shape[0] // numGroups`.
+ kernel (`paddle.Tensor`, *optional*):
+ FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+ corresponds to average pooling.
+ factor (`int`, *optional*, default to `2`):
+ Integer downsampling factor.
+ gain (`float`, *optional*, default to `1.0`):
+ Scaling factor for signal magnitude.
+
+ Returns:
+ output (`paddle.Tensor`):
+ Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
+ datatype as `x`.
+ """
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+ kernel = paddle.to_tensor(data=kernel, dtype='float32')
+ if kernel.ndim == 1:
+ kernel = paddle.outer(x=kernel, y=kernel)
+ kernel /= paddle.sum(x=kernel)
+ kernel = kernel * gain
+ if self.use_conv:
+ _, _, convH, convW = tuple(weight.shape)
+ pad_value = tuple(kernel.shape)[0] - factor + (convW - 1)
+ stride_value = [factor, factor]
+ upfirdn_input = upfirdn2d_native(hidden_states, paddle.to_tensor(data=kernel, place=hidden_states.place),
+ pad=((pad_value + 1) // 2, pad_value // 2))
+ output = paddle.nn.functional.conv2d(x=upfirdn_input, weight=weight, stride=stride_value, padding=0)
+ else:
+ pad_value = tuple(kernel.shape)[0] - factor
+ output = upfirdn2d_native(hidden_states, paddle.to_tensor(data=kernel, place=hidden_states.place),
+ down=factor, pad=((pad_value + 1) // 2, pad_value // 2))
+ return output
+
+ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+ if self.use_conv:
+ downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
+ hidden_states = downsample_input + self.Conv2d_0.bias.reshape([1, -1, 1, 1])
+ else:
+ hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
+ return hidden_states
+
+
+class KDownsample2D(paddle.nn.Layer):
+ """A 2D K-downsampling layer.
+
+ Parameters:
+ pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
+ """
+
+ def __init__(self, pad_mode: str='reflect'):
+ super().__init__()
+ self.pad_mode = pad_mode
+ kernel_1d = paddle.to_tensor(data=[[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
+ self.pad = tuple(kernel_1d.shape)[1] // 2 - 1
+ self.register_buffer(name='kernel', tensor=kernel_1d.T @ kernel_1d,
+ persistable=False)
+
+ def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
+ inputs = paddle.nn.functional.pad(x=inputs, pad=(self.pad,) * 4, mode=self.pad_mode, pad_from_left_axis=False)
+ weight = paddle.zeros(shape=[inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]],
+ dtype=inputs.dtype)
+ indices = paddle.arange(end=inputs.shape[1])
+ kernel = self.kernel.to(weight)[None, :].expand(shape=[inputs.shape[1], -1, -1])
+ weight[indices, indices] = kernel
+ return paddle.nn.functional.conv2d(x=inputs, weight=weight, stride=2)
+
+
+class CogVideoXDownsample3D(paddle.nn.Layer):
+ """
+ A 3D downsampling layer used in [CogVideoX](https://github.com/THUDM/CogVideo) by Tsinghua University & ZhipuAI.
+
+ Args:
+ in_channels (`int`):
+ Number of channels in the input image.
+ out_channels (`int`):
+ Number of channels produced by the convolution.
+ kernel_size (`int`, defaults to `3`):
+ Size of the convolving kernel.
+ stride (`int`, defaults to `2`):
+ Stride of the convolution.
+ padding (`int`, defaults to `0`):
+ Padding added to all four sides of the input.
+ compress_time (`bool`, defaults to `False`):
+ Whether or not to compress the time dimension.
+ """
+
+ def __init__(self, in_channels: int, out_channels: int, kernel_size: int=3, stride: int=2,
+ padding: int=0, compress_time: bool=False):
+ super().__init__()
+ self.conv = paddle.nn.Conv2D(in_channels=in_channels, out_channels=out_channels,
+ kernel_size=kernel_size, stride=stride, padding=padding)
+ self.compress_time = compress_time
+
+ def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+ if self.compress_time:
+ batch_size, channels, frames, height, width = tuple(x.shape)
+ x = x.transpose(perm=[0, 3, 4, 1, 2]).reshape([batch_size * height * width, channels, frames])
+ if x.shape[-1] % 2 == 1:
+ x_first, x_rest = x[..., 0], x[..., 1:]
+ if x_rest.shape[-1] > 0:
+ x_rest = paddle.nn.functional.avg_pool1d(x=x_rest, kernel_size=2, stride=2, exclusive=False)
+ x = paddle.concat(x=[x_first[..., None], x_rest], axis=-1)
+ x = x.reshape([batch_size, height, width, channels, x.shape[-1]]).transpose(perm=[0, 3, 4, 1, 2])
+ else:
+ x = paddle.nn.functional.avg_pool1d(x=x, kernel_size=2, stride=2, exclusive=False)
+ x = x.reshape([batch_size, height, width, channels, x.shape[-1]]).transpose(perm=[0, 3, 4, 1, 2])
+ pad = 0, 1, 0, 1
+ x = paddle.nn.functional.pad(x=x, pad=pad, mode='constant', value=0, pad_from_left_axis=False)
+ batch_size, channels, frames, height, width = tuple(x.shape)
+ x = x.transpose(perm=[0, 2, 1, 3, 4]).reshape([batch_size * frames, channels, height, width])
+ x = self.conv(x)
+ x = x.reshape([batch_size, frames, x.shape[1], x.shape[2], x.shape[3]]).transpose(perm=[0, 2, 1, 3, 4])
+ return x
+
+
+def downsample_2d(hidden_states: paddle.Tensor, kernel: Optional[paddle.Tensor]=None,
+ factor: int=2, gain: float=1) -> paddle.Tensor:
+ """Downsample2D a batch of 2D images with the given filter.
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
+ shape is a multiple of the downsampling factor.
+
+ Args:
+ hidden_states (`paddle.Tensor`):
+ Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ kernel (`paddle.Tensor`, *optional*):
+ FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+ corresponds to average pooling.
+ factor (`int`, *optional*, default to `2`):
+ Integer downsampling factor.
+ gain (`float`, *optional*, default to `1.0`):
+ Scaling factor for signal magnitude.
+
+ Returns:
+ output (`paddle.Tensor`):
+ Tensor of the shape `[N, C, H // factor, W // factor]`
+ """
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+ kernel = paddle.to_tensor(data=kernel, dtype='float32')
+ if kernel.ndim == 1:
+ kernel = paddle.outer(x=kernel, y=kernel)
+ kernel /= paddle.sum(x=kernel)
+ kernel = kernel * gain
+ pad_value = tuple(kernel.shape)[0] - factor
+ output = upfirdn2d_native(hidden_states, kernel.to(device=hidden_states.place),
+ down=factor, pad=((pad_value + 1) // 2, pad_value // 2))
+ return output
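+
+
+if __name__ == "__main__":
+    # Quick shape check (sketch with made-up sizes): a strided-conv Downsample2D halves
+    # the spatial resolution of an NCHW feature map.
+    down = Downsample2D(channels=16, use_conv=True)
+    x = paddle.randn([1, 16, 64, 64])
+    print(down(x).shape)  # expected: [1, 16, 32, 32]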
diff --git a/ppdiffusers/ppdiffusers/models/embeddings.py b/ppdiffusers/ppdiffusers/models/embeddings.py
index d339a3922..f80be4f16 100644
--- a/ppdiffusers/ppdiffusers/models/embeddings.py
+++ b/ppdiffusers/ppdiffusers/models/embeddings.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
-from typing import Optional
+from typing import List, Optional, Tuple, Union
import numpy as np
import paddle
@@ -913,3 +913,225 @@ def forward(self, caption):
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
+
+
+def get_3d_sincos_pos_embed(embed_dim: int, spatial_size: Union[int, Tuple[int, int]], temporal_size: int,
+ spatial_interpolation_scale: float=1.0, temporal_interpolation_scale: float=1.0) -> np.ndarray:
+ """
+ Args:
+ embed_dim (`int`):
+ spatial_size (`int` or `Tuple[int, int]`):
+ temporal_size (`int`):
+ spatial_interpolation_scale (`float`, defaults to 1.0):
+ temporal_interpolation_scale (`float`, defaults to 1.0):
+ """
+ if embed_dim % 4 != 0:
+ raise ValueError('`embed_dim` must be divisible by 4')
+ if isinstance(spatial_size, int):
+ spatial_size = spatial_size, spatial_size
+ embed_dim_spatial = 3 * embed_dim // 4
+ embed_dim_temporal = embed_dim // 4
+ grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
+ grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
+ grid = np.meshgrid(grid_w, grid_h)
+ grid = np.stack(grid, axis=0)
+ grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+ pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+ grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
+ pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+ pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+ pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)
+ pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+ pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)
+ pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)
+ return pos_embed
+
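+# Shape sketch (hypothetical sizes, for orientation): with embed_dim=64,
+# spatial_size=(45, 30) given as (width, height) and temporal_size=13, the returned
+# array has shape (13, 30 * 45, 64); three quarters of the channels encode the spatial
+# position and one quarter the temporal position, matching the split above.
+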
+class CogVideoXPatchEmbed(paddle.nn.Layer):
+
+ def __init__(
+ self,
+ patch_size: int=2,
+ in_channels: int=16,
+ embed_dim: int=1920,
+ text_embed_dim: int=4096,
+ bias: bool=True,
+ sample_width: int=90,
+ sample_height: int=60,
+ sample_frames: int=49,
+ temporal_compression_ratio: int=4,
+ max_text_seq_length: int=226,
+ spatial_interpolation_scale: float=1.875,
+ temporal_interpolation_scale: float=1.0,
+ use_positional_embeddings: bool=True,
+ use_learned_positional_embeddings: bool=True
+ ) ->None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.embed_dim = embed_dim
+ self.sample_height = sample_height
+ self.sample_width = sample_width
+ self.sample_frames = sample_frames
+ self.temporal_compression_ratio = temporal_compression_ratio
+ self.max_text_seq_length = max_text_seq_length
+ self.spatial_interpolation_scale = spatial_interpolation_scale
+ self.temporal_interpolation_scale = temporal_interpolation_scale
+ self.use_positional_embeddings = use_positional_embeddings
+ self.use_learned_positional_embeddings = use_learned_positional_embeddings
+
+ self.proj = paddle.nn.Conv2D(in_channels=in_channels, out_channels=embed_dim,
+ kernel_size=(patch_size, patch_size), stride=patch_size, bias_attr=bias)
+
+ self.text_proj = paddle.nn.Linear(in_features=text_embed_dim, out_features=embed_dim)
+
+ if use_positional_embeddings or use_learned_positional_embeddings:
+ persistent = use_learned_positional_embeddings
+ pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
+ self.register_buffer(name='pos_embedding', tensor=pos_embedding,
+ persistable=persistent)
+
+ def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> paddle.Tensor:
+ post_patch_height = sample_height // self.patch_size
+ post_patch_width = sample_width // self.patch_size
+ post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
+ num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+ pos_embedding = get_3d_sincos_pos_embed(self.embed_dim, (post_patch_width, post_patch_height),
+ post_time_compression_frames, self.spatial_interpolation_scale, self.temporal_interpolation_scale)
+ pos_embedding = paddle.to_tensor(data=pos_embedding).flatten(start_axis=0, stop_axis=1)
+ joint_pos_embedding = paddle.zeros([1, self.max_text_seq_length + num_patches, self.embed_dim])
+ joint_pos_embedding[0, self.max_text_seq_length:] = pos_embedding
+ return joint_pos_embedding
+
+ def forward(self, text_embeds: paddle.Tensor, image_embeds: paddle.Tensor):
+ """
+ Args:
+ text_embeds (`paddle.Tensor`):
+ Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
+ image_embeds (`paddle.Tensor`):
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
+ """
+ text_embeds = self.text_proj(text_embeds)
+
+ batch, num_frames, channels, height, width = image_embeds.shape
+ image_embeds = image_embeds.reshape([-1, channels, height, width])
+ image_embeds = self.proj(image_embeds)
+ image_embeds = image_embeds.reshape([batch, num_frames] + image_embeds.shape[1:])
+ image_embeds = image_embeds.flatten(3).transpose([0, 1, 3, 2]) # [batch, num_frames, height x width, channels]
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
+
+ embeds = paddle.concat(x=[text_embeds, image_embeds], axis=1).contiguous()
+ if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+ if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != height):
+ raise ValueError(
+ "It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'. If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
+ )
+ pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+ if (self.sample_height != height or self.sample_width != width or
+ self.sample_frames != pre_time_compression_frames):
+ pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
+ pos_embedding = pos_embedding.cast(embeds.dtype)
+ else:
+ pos_embedding = self.pos_embedding
+ embeds = embeds + pos_embedding
+ return embeds
+
+def get_3d_rotary_pos_embed(embed_dim, crops_coords, grid_size, temporal_size, theta: int=10000,
+ use_real: bool=True) -> Union[paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor]]:
+ """
+ RoPE for video tokens with 3D structure.
+
+ Args:
+ embed_dim: (`int`):
+ The embedding dimension size, corresponding to hidden_size_head.
+ crops_coords (`Tuple[int]`):
+ The top-left and bottom-right coordinates of the crop.
+ grid_size (`Tuple[int]`):
+ The grid size of the spatial positional embedding (height, width).
+ temporal_size (`int`):
+ The size of the temporal dimension.
+ theta (`float`):
+ Scaling factor for frequency computation.
+ use_real (`bool`):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+ Returns:
+ `paddle.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
+ """
+ start, stop = crops_coords
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+ dim_t = embed_dim // 4
+ dim_h = embed_dim // 8 * 3
+ dim_w = embed_dim // 8 * 3
+ freqs_t = 1.0 / theta ** (paddle.arange(start=0, end=dim_t, step=2).astype(dtype='float32') / dim_t)
+ grid_t = paddle.to_tensor(data=grid_t).astype(dtype='float32')
+ freqs_t = paddle.einsum('n , f -> n f', grid_t, freqs_t)
+ freqs_t = freqs_t.repeat_interleave(repeats=2, axis=-1)
+ freqs_h = 1.0 / theta ** (paddle.arange(start=0, end=dim_h, step=2).astype(dtype='float32') / dim_h)
+ freqs_w = 1.0 / theta ** (paddle.arange(start=0, end=dim_w, step=2).astype(dtype='float32') / dim_w)
+ grid_h = paddle.to_tensor(data=grid_h).astype(dtype='float32')
+ grid_w = paddle.to_tensor(data=grid_w).astype(dtype='float32')
+ freqs_h = paddle.einsum('n , f -> n f', grid_h, freqs_h)
+ freqs_w = paddle.einsum('n , f -> n f', grid_w, freqs_w)
+ freqs_h = freqs_h.repeat_interleave(repeats=2, axis=-1)
+ freqs_w = freqs_w.repeat_interleave(repeats=2, axis=-1)
+
+ def broadcast(tensors, dim=-1):
+ num_tensors = len(tensors)
+ shape_lens = {len(tuple(t.shape)) for t in tensors}
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
+ shape_len = list(shape_lens)[0]
+ dim = dim + shape_len if dim < 0 else dim
+ dims = list(zip(*(list(tuple(t.shape)) for t in tensors)))
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+ assert all(len(set(t[1])) <= 2 for t in expandable_dims), 'invalid dimensions for broadcastable concatenation'
+ max_dims = [(t[0], max(t[1])) for t in expandable_dims]
+ expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
+ expanded_dims.insert(dim, (dim, dims[dim]))
+ expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
+ tensors = [t[0].expand(shape=t[1]) for t in zip(tensors, expandable_shapes)]
+ return paddle.concat(x=tensors, axis=dim)
+
+ freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
+ t, h, w, d = tuple(freqs.shape)
+ freqs = freqs.view(t * h * w, d)
+ sin = freqs.sin()
+ cos = freqs.cos()
+ if use_real:
+ return cos, sin
+ else:
+ freqs_cis = paddle.complex(paddle.ones_like(x=freqs) * paddle.cos(
+ freqs), paddle.ones_like(x=freqs) * paddle.sin(freqs))
+ return freqs_cis
\ No newline at end of file
diff --git a/ppdiffusers/ppdiffusers/models/normalization.py b/ppdiffusers/ppdiffusers/models/normalization.py
index 68a732cc0..e582905f6 100644
--- a/ppdiffusers/ppdiffusers/models/normalization.py
+++ b/ppdiffusers/ppdiffusers/models/normalization.py
@@ -32,17 +32,45 @@ class AdaLayerNorm(nn.Layer):
num_embeddings (`int`): The size of the embeddings dictionary.
"""
- def __init__(self, embedding_dim: int, num_embeddings: int):
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_embeddings: Optional[int] = None,
+ output_dim: Optional[int] = None,
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-5,
+ chunk_dim: int = 0,
+ ):
super().__init__()
- self.emb = nn.Embedding(num_embeddings, embedding_dim)
+
+ self.chunk_dim = chunk_dim
+ output_dim = output_dim or embedding_dim * 2
+
+ if num_embeddings is not None:
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
+ else:
+ self.emb = None
+
self.silu = nn.Silu()
- self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
- norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False)
- self.norm = nn.LayerNorm(embedding_dim, **norm_elementwise_affine_kwargs)
+ self.linear = nn.Linear(embedding_dim, output_dim)
+ if norm_elementwise_affine:
+ norm_elementwise_affine_kwargs = dict(weight_attr=None, bias_attr=None)
+ else:
+ norm_elementwise_affine_kwargs = dict(weight_attr=False, bias_attr=False)
+ self.norm = nn.LayerNorm(output_dim // 2, epsilon=norm_eps, **norm_elementwise_affine_kwargs)
- def forward(self, x: paddle.Tensor, timestep: paddle.Tensor) -> paddle.Tensor:
- emb = self.linear(self.silu(self.emb(timestep)))
- scale, shift = paddle.chunk(emb, 2)
+ def forward(self, x: paddle.Tensor, timestep: Optional[paddle.Tensor] = None, temb: Optional[paddle.Tensor] = None) -> paddle.Tensor:
+ if self.emb is not None:
+ temb = self.emb(timestep)
+ temb = self.linear(self.silu(temb))
+ if self.chunk_dim == 1:
+ # It is a bit odd that the order is "shift, scale" here and "scale, shift" in the
+ # other branch; this branch is specific to CogVideoX for now.
+ shift, scale = paddle.chunk(temb, 2, axis=1)
+ shift = shift[:, None, :]
+ scale = scale[:, None, :]
+ else:
+ scale, shift = paddle.chunk(temb, 2)
x = self.norm(x) * (1 + scale) + shift
return x
@@ -224,3 +252,29 @@ def forward(self, hidden_states):
epsilon=self.epsilon,
begin_norm_axis=2,
)
+
+
+class CogVideoXLayerNormZero(paddle.nn.Layer):
+
+ def __init__(
+ self, conditioning_dim: int, embedding_dim: int, elementwise_affine: bool = True, eps: float = 1e-05, bias: bool = True
+ ) -> None:
+ super().__init__()
+ self.silu = paddle.nn.Silu()
+ self.linear = paddle.nn.Linear(in_features=conditioning_dim, out_features=6 * embedding_dim, bias_attr=bias)
+ self.norm = paddle.nn.LayerNorm(normalized_shape=embedding_dim, epsilon=eps, weight_attr=elementwise_affine, bias_attr=elementwise_affine)
+
+ def forward(
+ self, hidden_states: paddle.Tensor, encoder_hidden_states: paddle.Tensor, temb: paddle.Tensor
+ ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+ shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(chunks=6, axis=1)
+ hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
+ encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
+ return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
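+
+# Illustrative sketch (an assumption, mirroring the usual adaLN-Zero pattern): a
+# CogVideoX transformer block is expected to apply the returned gates to its
+# sub-layer outputs before the residual addition, e.g.
+#
+#     norm_h, norm_e, gate, enc_gate = norm1(hidden_states, encoder_hidden_states, temb)
+#     attn_h, attn_e = attn(norm_h, norm_e)  # hypothetical attention call
+#     hidden_states = hidden_states + gate * attn_h
+#     encoder_hidden_states = encoder_hidden_states + enc_gate * attn_e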
diff --git a/ppdiffusers/ppdiffusers/models/upsampling.py b/ppdiffusers/ppdiffusers/models/upsampling.py
new file mode 100644
index 000000000..f9ac22d6b
--- /dev/null
+++ b/ppdiffusers/ppdiffusers/models/upsampling.py
@@ -0,0 +1,401 @@
+from typing import Optional, Tuple
+
+import paddle
+
+from ..utils import deprecate
+from .normalization import RMSNorm
+
+
+class Upsample1D(paddle.nn.Layer):
+ """A 1D upsampling layer with an optional convolution.
+
+ Parameters:
+ channels (`int`):
+ number of channels in the inputs and outputs.
+ use_conv (`bool`, default `False`):
+ option to use a convolution.
+ use_conv_transpose (`bool`, default `False`):
+ option to use a convolution transpose.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ name (`str`, default `conv`):
+ name of the upsampling 1D layer.
+ """
+
+ def __init__(
+ self, channels: int, use_conv: bool = False, use_conv_transpose: bool = False, out_channels: Optional[int] = None, name: str = 'conv'
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+ self.conv = None
+ if use_conv_transpose:
+ self.conv = paddle.nn.Conv1DTranspose(in_channels=channels, out_channels=self.out_channels, kernel_size=4, stride=2, padding=1)
+ elif use_conv:
+ self.conv = paddle.nn.Conv1D(in_channels=self.channels, out_channels=self.out_channels, kernel_size=3, padding=1)
+
+ def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
+ assert tuple(inputs.shape)[1] == self.channels
+ if self.use_conv_transpose:
+ return self.conv(inputs)
+ outputs = paddle.nn.functional.interpolate(x=inputs, scale_factor=2.0, mode='nearest')
+ if self.use_conv:
+ outputs = self.conv(outputs)
+ return outputs
+
+
+class Upsample2D(paddle.nn.Layer):
+ """A 2D upsampling layer with an optional convolution.
+
+ Parameters:
+ channels (`int`):
+ number of channels in the inputs and outputs.
+ use_conv (`bool`, default `False`):
+ option to use a convolution.
+ use_conv_transpose (`bool`, default `False`):
+ option to use a convolution transpose.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ name (`str`, default `conv`):
+ name of the upsampling 2D layer.
+ """
+
+ def __init__(
+ self, channels: int, use_conv: bool = False, use_conv_transpose: bool = False, out_channels: Optional[int] = None,
+ name: str = 'conv', kernel_size: Optional[int] = None, padding=1, norm_type=None, eps=None,
+ elementwise_affine=None, bias=True, interpolate=True,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+ self.interpolate = interpolate
+ if norm_type == 'ln_norm':
+ self.norm = paddle.nn.LayerNorm(normalized_shape=channels, epsilon=eps, weight_attr=elementwise_affine, bias_attr=elementwise_affine)
+ elif norm_type == 'rms_norm':
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
+ elif norm_type is None:
+ self.norm = None
+ else:
+ raise ValueError(f'unknown norm_type: {norm_type}')
+ conv = None
+ if use_conv_transpose:
+ if kernel_size is None:
+ kernel_size = 4
+ conv = paddle.nn.Conv2DTranspose(
+ in_channels=channels, out_channels=self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias_attr=bias
+ )
+ elif use_conv:
+ if kernel_size is None:
+ kernel_size = 3
+ conv = paddle.nn.Conv2D(
+ in_channels=self.channels, out_channels=self.out_channels, kernel_size=kernel_size, padding=padding, bias_attr=bias
+ )
+ if name == 'conv':
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(self, hidden_states: paddle.Tensor, output_size: Optional[int] = None, *args, **kwargs) -> paddle.Tensor:
+ if len(args) > 0 or kwargs.get('scale', None) is not None:
+ deprecation_message = (
+ 'The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`.'
+ )
+ deprecate('scale', '1.0.0', deprecation_message)
+ assert tuple(hidden_states.shape)[1] == self.channels
+ if self.norm is not None:
+ hidden_states = self.norm(hidden_states.transpose(perm=[0, 2, 3, 1])).transpose(perm=[0, 3, 1, 2])
+ if self.use_conv_transpose:
+ return self.conv(hidden_states)
+ dtype = hidden_states.dtype
+ if dtype == paddle.bfloat16:
+ # cast to float32 for the nearest-neighbor interpolation, then cast back below
+ hidden_states = hidden_states.cast('float32')
+ if tuple(hidden_states.shape)[0] >= 64:
+ hidden_states = hidden_states.contiguous()
+ if self.interpolate:
+ if output_size is None:
+ hidden_states = paddle.nn.functional.interpolate(x=hidden_states, scale_factor=2.0, mode='nearest')
+ else:
+ hidden_states = paddle.nn.functional.interpolate(x=hidden_states, size=output_size, mode='nearest')
+ if dtype == paddle.bfloat16:
+ hidden_states = hidden_states.cast(dtype)
+ if self.use_conv:
+ if self.name == 'conv':
+ hidden_states = self.conv(hidden_states)
+ else:
+ hidden_states = self.Conv2d_0(hidden_states)
+ return hidden_states
+
+
+class FirUpsample2D(paddle.nn.Layer):
+ """A 2D FIR upsampling layer with an optional convolution.
+
+ Parameters:
+ channels (`int`, optional):
+ number of channels in the inputs and outputs.
+ use_conv (`bool`, default `False`):
+ option to use a convolution.
+ out_channels (`int`, optional):
+ number of output channels. Defaults to `channels`.
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
+ kernel for the FIR filter.
+ """
+
+ def __init__(
+ self, channels: Optional[int] = None, out_channels: Optional[int] = None, use_conv: bool = False, fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1)
+ ):
+ super().__init__()
+ out_channels = out_channels if out_channels else channels
+ if use_conv:
+ self.Conv2d_0 = paddle.nn.Conv2D(in_channels=channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
+ self.use_conv = use_conv
+ self.fir_kernel = fir_kernel
+ self.out_channels = out_channels
+
+ def _upsample_2d(
+ self, hidden_states: paddle.Tensor, weight: Optional[paddle.Tensor] = None, kernel: Optional[paddle.Tensor] = None, factor: int = 2, gain: float = 1
+ ) -> paddle.Tensor:
+ """Fused `upsample_2d()` followed by `Conv2d()`.
+
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
+ arbitrary order.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ weight (`torch.Tensor`, *optional*):
+ Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
+ performed by `inChannels = x.shape[0] // numGroups`.
+ kernel (`torch.Tensor`, *optional*):
+ FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+ corresponds to nearest-neighbor upsampling.
+ factor (`int`, *optional*): Integer upsampling factor (default: 2).
+ gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ output (`torch.Tensor`):
+ Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
+ datatype as `hidden_states`.
+ """
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+ kernel = paddle.to_tensor(data=kernel, dtype='float32')
+ if kernel.ndim == 1:
+ kernel = paddle.outer(x=kernel, y=kernel)
+ kernel /= paddle.sum(x=kernel)
+ kernel = kernel * (gain * factor ** 2)
+ if self.use_conv:
+ convH = tuple(weight.shape)[2]
+ convW = tuple(weight.shape)[3]
+ inC = tuple(weight.shape)[1]
+ pad_value = tuple(kernel.shape)[0] - factor - (convW - 1)
+ stride = factor, factor
+ output_shape = (
+ (tuple(hidden_states.shape)[2] - 1) * factor + convH,
+ (tuple(hidden_states.shape)[3] - 1) * factor + convW,
+ )
+ output_padding = (
+ output_shape[0] - (tuple(hidden_states.shape)[2] - 1) * stride[0] - convH,
+ output_shape[1] - (tuple(hidden_states.shape)[3] - 1) * stride[1] - convW,
+ )
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
+ num_groups = tuple(hidden_states.shape)[1] // inC
+ weight = paddle.reshape(x=weight, shape=(num_groups, -1, inC, convH, convW))
+ weight = paddle.flip(x=weight, axis=[3, 4]).transpose(perm=[0, 2, 1, 3, 4])
+ weight = paddle.reshape(x=weight, shape=(num_groups * inC, -1, convH, convW))
+ inverse_conv = paddle.nn.functional.conv2d_transpose(
+ x=hidden_states, weight=weight, stride=stride, output_padding=output_padding, padding=0
+ )
+ output = upfirdn2d_native(
+ inverse_conv, paddle.to_tensor(data=kernel), pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1)
+ )
+ else:
+ pad_value = tuple(kernel.shape)[0] - factor
+ output = upfirdn2d_native(
+ hidden_states, paddle.to_tensor(data=kernel), up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2)
+ )
+ return output
+
+ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+ if self.use_conv:
+ height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
+ height = height + self.Conv2d_0.bias.reshape([1, -1, 1, 1])
+ else:
+ height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
+ return height
+
+
+class KUpsample2D(paddle.nn.Layer):
+ """A 2D K-upsampling layer.
+
+ Parameters:
+ pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
+ """
+
+ def __init__(self, pad_mode: str='reflect'):
+ super().__init__()
+ self.pad_mode = pad_mode
+ kernel_1d = paddle.to_tensor(data=[[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
+ self.pad = tuple(kernel_1d.shape)[1] // 2 - 1
+ self.register_buffer(name='kernel', tensor=kernel_1d.T @ kernel_1d, persistable=False)
+
+ def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
+ inputs = paddle.nn.functional.pad(x=inputs, pad=((self.pad + 1) // 2,) * 4, mode=self.pad_mode, pad_from_left_axis=False)
+ weight = paddle.zeros(
+ shape=[tuple(inputs.shape)[1], tuple(inputs.shape)[1], tuple(self.kernel.shape)[0], tuple(self.kernel.shape)[1]],
+ dtype=inputs.dtype,
+ )
+ indices = paddle.arange(end=tuple(inputs.shape)[1])
+ kernel = self.kernel.cast(weight.dtype)[None, :].expand(shape=[tuple(inputs.shape)[1], -1, -1])
+ weight[indices, indices] = kernel
+ return paddle.nn.functional.conv2d_transpose(x=inputs, weight=weight, stride=2, padding=self.pad * 2 + 1)
+
+
+class CogVideoXUpsample3D(paddle.nn.Layer):
+ """
+ A 3D upsampling layer used in CogVideoX by Tsinghua University & ZhipuAI. # TODO: update the reference once the paper is released.
+
+ Args:
+ in_channels (`int`):
+ Number of channels in the input image.
+ out_channels (`int`):
+ Number of channels produced by the convolution.
+ kernel_size (`int`, defaults to `3`):
+ Size of the convolving kernel.
+ stride (`int`, defaults to `1`):
+ Stride of the convolution.
+ padding (`int`, defaults to `1`):
+ Padding added to all four sides of the input.
+ compress_time (`bool`, defaults to `False`):
+ Whether or not to compress the time dimension.
+ """
+
+ def __init__(
+ self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, padding: int = 1, compress_time: bool = False
+ ) -> None:
+ super().__init__()
+ self.conv = paddle.nn.Conv2D(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding
+ )
+ self.compress_time = compress_time
+
+ def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
+ if self.compress_time:
+ if tuple(inputs.shape)[2] > 1 and tuple(inputs.shape)[2] % 2 == 1:
+ # split the first frame from the rest and upsample them separately
+ x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
+ x_first = paddle.nn.functional.interpolate(x=x_first, scale_factor=2.0)
+ x_rest = paddle.nn.functional.interpolate(x=x_rest, scale_factor=2.0)
+ x_first = x_first[:, :, None, :, :]
+ inputs = paddle.concat(x=[x_first, x_rest], axis=2)
+ elif tuple(inputs.shape)[2] > 1:
+ inputs = paddle.nn.functional.interpolate(x=inputs, scale_factor=2.0)
+ else:
+ inputs = inputs.squeeze(axis=2)
+ inputs = paddle.nn.functional.interpolate(x=inputs, scale_factor=2.0)
+ inputs = inputs[:, :, None, :, :]
+ else:
+ # only interpolate over the spatial dimensions
+ b, c, t, h, w = tuple(inputs.shape)
+ inputs = inputs.transpose(perm=[0, 2, 1, 3, 4]).reshape([b * t, c, h, w])
+ inputs = paddle.nn.functional.interpolate(x=inputs, scale_factor=2.0)
+ inputs = inputs.reshape([b, t, c] + inputs.shape[2:]).transpose(perm=[0, 2, 1, 3, 4])
+
+ b, c, t, h, w = inputs.shape
+ inputs = inputs.transpose(perm=[0, 2, 1, 3, 4]).reshape([b * t, c, h, w])
+ inputs = self.conv(inputs)
+ inputs = inputs.reshape([b, t] + inputs.shape[1:]).transpose(perm=[0, 2, 1, 3, 4])
+ return inputs
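+
+ # Shape sketch (illustrative): with `compress_time=False`, a latent video of
+ # shape [B, C, T, H, W] is upsampled frame by frame to [B, C_out, T, 2H, 2W];
+ # with `compress_time=True` the temporal axis is also (roughly) doubled, the
+ # first frame being handled separately when T is odd. For hypothetical sizes:
+ #
+ #     up = CogVideoXUpsample3D(in_channels=256, out_channels=256)
+ #     y = up(paddle.randn([1, 256, 5, 60, 90]))  # -> [1, 256, 5, 120, 180]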
+
+
+def upfirdn2d_native(tensor: paddle.Tensor, kernel: paddle.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)) -> paddle.Tensor:
+ up_x = up_y = up
+ down_x = down_y = down
+ pad_x0 = pad_y0 = pad[0]
+ pad_x1 = pad_y1 = pad[1]
+ _, channel, in_h, in_w = tuple(tensor.shape)
+ tensor = tensor.reshape([-1, in_h, in_w, 1])
+ _, in_h, in_w, minor = tuple(tensor.shape)
+ kernel_h, kernel_w = tuple(kernel.shape)
+ out = tensor.reshape([-1, in_h, 1, in_w, 1, minor])
+ out = paddle.nn.functional.pad(x=out, pad=[0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1], pad_from_left_axis=False)
+ out = out.reshape([-1, in_h * up_y, in_w * up_x, minor])
+ out = paddle.nn.functional.pad(x=out, pad=[0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)], pad_from_left_axis=False)
+ out = out[:, max(-pad_y0, 0):tuple(out.shape)[1] - max(-pad_y1, 0), max(-pad_x0, 0):tuple(out.shape)[2] - max(-pad_x1, 0), :]
+ out = out.transpose(perm=[0, 3, 1, 2])
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = paddle.flip(x=kernel, axis=[0, 1]).reshape([1, 1, kernel_h, kernel_w])
+ out = paddle.nn.functional.conv2d(x=out, weight=w)
+ out = out.reshape([-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1])
+ out = out.transpose(perm=[0, 2, 3, 1])
+ out = out[:, ::down_y, ::down_x, :]
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+ return out.reshape([-1, channel, out_h, out_w])
+
+
+def upsample_2d(hidden_states: paddle.Tensor, kernel: Optional[paddle.Tensor] = None, factor: int = 2, gain: float = 1) -> paddle.Tensor:
+ """Upsample2D a batch of 2D images with the given filter.
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
+ a: multiple of the upsampling factor.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+ kernel (`torch.Tensor`, *optional*):
+ FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+ corresponds to nearest-neighbor upsampling.
+ factor (`int`, *optional*, default to `2`):
+ Integer upsampling factor.
+ gain (`float`, *optional*, default to `1.0`):
+ Scaling factor for signal magnitude (default: 1.0).
+
+ Returns:
+ output (`torch.Tensor`):
+ Tensor of the shape `[N, C, H * factor, W * factor]`
+ """
+ assert isinstance(factor, int) and factor >= 1
+ if kernel is None:
+ kernel = [1] * factor
+ kernel = paddle.to_tensor(data=kernel, dtype='float32')
+ if kernel.ndim == 1:
+ kernel = paddle.outer(x=kernel, y=kernel)
+ kernel /= paddle.sum(x=kernel)
+ kernel = kernel * (gain * factor ** 2)
+ pad_value = tuple(kernel.shape)[0] - factor
+ output = upfirdn2d_native(hidden_states, kernel, up=factor, pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2))
+ return output
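+
+# Usage sketch (illustrative): with the default kernel (`[1] * factor`) this FIR
+# upsampling reduces to nearest-neighbor duplication, e.g. for a hypothetical input
+#
+#     x = paddle.randn([1, 3, 64, 64])
+#     y = upsample_2d(x, factor=2)  # -> [1, 3, 128, 128]
+#
+# while a smoothing kernel such as `[1, 3, 3, 1]` gives the FIR-filtered variant.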
diff --git a/ppdiffusers/ppdiffusers/pipelines/__init__.py b/ppdiffusers/ppdiffusers/pipelines/__init__.py
index b909c884d..d0a858e6c 100644
--- a/ppdiffusers/ppdiffusers/pipelines/__init__.py
+++ b/ppdiffusers/ppdiffusers/pipelines/__init__.py
@@ -101,6 +101,7 @@
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
+ _import_structure["cogvideo"] = ["CogVideoXPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
@@ -375,6 +376,7 @@
AudioLDM2UNet2DConditionModel,
)
from .blip_diffusion import BlipDiffusionPipeline
+ from .cogvideo import CogVideoXPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
PaddleInferStableDiffusionControlNetPipeline,
diff --git a/ppdiffusers/ppdiffusers/pipelines/cogvideo/__init__.py b/ppdiffusers/ppdiffusers/pipelines/cogvideo/__init__.py
new file mode 100644
index 000000000..8f333d821
--- /dev/null
+++ b/ppdiffusers/ppdiffusers/pipelines/cogvideo/__init__.py
@@ -0,0 +1,54 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ PPDIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_paddle_available,
+ is_paddlenlp_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_paddlenlp_available() and is_paddle_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_paddle_and_paddlenlp_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_paddle_and_paddlenlp_objects))
+else:
+ _import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
+ # _import_structure["pipeline_cogvideox_fun_control"] = ["CogVideoXFunControlPipeline"]
+ # _import_structure["pipeline_cogvideox_image2video"] = ["CogVideoXImageToVideoPipeline"]
+ # _import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"]
+
+if TYPE_CHECKING or PPDIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_paddlenlp_available() and is_paddle_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_paddle_and_paddlenlp_objects import *
+ else:
+ from .pipeline_cogvideox import CogVideoXPipeline
+ # from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline
+ # from .pipeline_cogvideox_image2video import CogVideoXImageToVideoPipeline
+ # from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
\ No newline at end of file
diff --git a/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox.py b/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox.py
new file mode 100644
index 000000000..2e472c71f
--- /dev/null
+++ b/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -0,0 +1,718 @@
+# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import paddle
+# from paddlenlp.transformers import T5EncoderModel, T5Tokenizer
+from ppdiffusers.transformers import T5EncoderModel, T5Tokenizer
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
+from ...models.embeddings import get_3d_rotary_pos_embed
+from ..pipeline_utils import DiffusionPipeline
+from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
+from ...utils import logging, replace_example_docstring
+from ...utils.paddle_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from .pipeline_output import CogVideoXPipelineOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> import paddle
+ >>> from ppdiffusers import CogVideoXPipeline
+ >>> from ppdiffusers.utils import export_to_video
+
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
+ >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", paddle_dtype=paddle.float16)
+ >>> prompt = (
+ ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+ ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+ ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+ ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+ ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+ ... "atmosphere of this unique musical performance."
+ ... )
+ >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+ >>> export_to_video(video, "output.mp4", fps=8)
+ ```
+"""
+
+
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
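+
+# Worked example (for orientation only): at the default 480x720 resolution the
+# latent patch grid is 30x45 (480 / (8 * 2) by 720 / (8 * 2)) and the base grid is
+# 45 wide by 30 high, so
+#
+#     get_resize_crop_region_for_grid((30, 45), 45, 30)
+#     # -> ((0, 0), (30, 45)), i.e. the full grid with no cropping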
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
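+
+# Usage sketch (illustrative): most callers simply forward `num_inference_steps`,
+#
+#     timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=50)
+#
+# while a custom descending schedule, e.g. `timesteps=[999, 749, 499, 249]`, is only
+# valid for schedulers whose `set_timesteps` accepts a `timesteps` argument.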
+
+
+class CogVideoXPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-video generation using CogVideoX.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. CogVideoX uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`CogVideoXTransformer3DModel`]):
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ _optional_components = []
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+
+ def __init__(
+ self,
+ tokenizer: T5Tokenizer,
+ text_encoder: T5EncoderModel,
+ vae: AutoencoderKLCogVideoX,
+ transformer: CogVideoXTransformer3DModel,
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+ self.vae_scale_factor_spatial = (
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ )
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
+ )
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ dtype: Optional[paddle.dtype] = None,
+ ):
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pd",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pd").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not paddle.equal_all(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder(text_input_ids)[0]
+ prompt_embeds = prompt_embeds.cast(dtype)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.tile([1, num_videos_per_prompt, 1])
+ prompt_embeds = prompt_embeds.reshape([batch_size * num_videos_per_prompt, seq_len, -1])
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[paddle.Tensor] = None,
+ negative_prompt_embeds: Optional[paddle.Tensor] = None,
+ max_sequence_length: int = 226,
+ dtype: Optional[paddle.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt.
+ prompt_embeds (`paddle.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`paddle.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ dtype: (`paddle.dtype`, *optional*):
+ The paddle dtype used for the computation of the embeddings.
+ """
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def prepare_latents(
+ self, batch_size, num_channels_latents, num_frames, height, width, dtype, generator, latents=None
+ ):
+ shape = (
+ batch_size,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ num_channels_latents,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, dtype=dtype)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
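+ # Shape sketch (illustrative): with the defaults used by this pipeline
+ # (num_frames=49, height=480, width=720, temporal/spatial VAE factors of 4 and 8)
+ # and a hypothetical 16-channel transformer, the latents produced above have shape
+ # [batch_size, (49 - 1) // 4 + 1, 16, 480 // 8, 720 // 8] = [batch_size, 13, 16, 60, 90].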
+
+ def decode_latents(self, latents: paddle.Tensor) -> paddle.Tensor:
+ latents = latents.transpose([0, 2, 1, 3, 4]) # [batch_size, num_channels, num_frames, height, width]
+ latents = 1 / self.vae.config.scaling_factor * latents
+
+ frames = self.vae.decode(latents).sample
+ return frames
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def fuse_qkv_projections(self) -> None:
+ r"""Enables fused QKV projections."""
+ self.fusing_transformer = True
+ self.transformer.fuse_qkv_projections()
+
+ def unfuse_qkv_projections(self) -> None:
+ r"""Disable QKV projection fusion if enabled."""
+ if not self.fusing_transformer:
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.transformer.unfuse_qkv_projections()
+ self.fusing_transformer = False
+
+ def _prepare_rotary_positional_embeddings(
+ self,
+ height: int,
+ width: int,
+ num_frames: int,
+ ) -> Tuple[paddle.Tensor, paddle.Tensor]:
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+ grid_crops_coords = get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ use_real=True,
+ )
+
+ return freqs_cos, freqs_sin
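+ # Grid sketch (illustrative): at 480x720 with a spatial VAE factor of 8 and a
+ # patch size of 2, `grid_height, grid_width = 30, 45`, the base grid is 45x30, the
+ # crop region covers the full grid, and the returned cos/sin tensors each have
+ # shape [num_frames * 30 * 45, attention_head_dim], where num_frames here is the
+ # latent frame count passed in by `__call__`.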
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @paddle.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 480,
+ width: int = 720,
+ num_frames: int = 49,
+ num_inference_steps: int = 50,
+ timesteps: Optional[List[int]] = None,
+ guidance_scale: float = 6,
+ use_dynamic_cfg: bool = False,
+ num_videos_per_prompt: int = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[paddle.Generator, List[paddle.Generator]]] = None,
+ latents: Optional[paddle.Tensor] = None,
+ prompt_embeds: Optional[paddle.Tensor] = None,
+ negative_prompt_embeds: Optional[paddle.Tensor] = None,
+ output_type: str = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 226,
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, *optional*, defaults to `480`):
+ The height in pixels of the generated video. This is set to 480 by default for the best results.
+ width (`int`, *optional*, defaults to `720`):
+ The width in pixels of the generated video. This is set to 720 by default for the best results.
+ num_frames (`int`, defaults to `49`):
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
+ num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
+ needs to be satisfied is that of divisibility mentioned above.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to `6`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`paddle.Generator` or `List[paddle.Generator]`, *optional*):
+ One or a list of paddle generator(s) to make generation deterministic.
+ latents (`paddle.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`paddle.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`paddle.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] instead
+ of a plain tuple.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `226`):
+ Maximum sequence length in encoded prompt. Must be consistent with
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
+ [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ if num_frames > 49:
+ raise ValueError(
+ "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
+ )
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
+ width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+ self._guidance_scale = guidance_scale
+ self._interrupt = False
+
+ # 2. Default call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ negative_prompt,
+ do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_classifier_free_guidance:
+ prompt_embeds = paddle.concat([negative_prompt_embeds, prompt_embeds], axis=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, timesteps)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Create rotary embeds if required
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(height, width, latents.shape[1])
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ # for DPM-solver++
+ old_pred_original_sample = None
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = paddle.concat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred.cast('float32')
+
+ # perform guidance
+ if use_dynamic_cfg:
+ self._guidance_scale = 1 + guidance_scale * (
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+ )
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ else:
+ latents, old_pred_original_sample = self.scheduler.step(
+ noise_pred,
+ old_pred_original_sample,
+ t,
+ timesteps[i - 1] if i > 0 else None,
+ latents,
+ **extra_step_kwargs,
+ return_dict=False,
+ )
+ latents = latents.cast(prompt_embeds.dtype)
+
+ # call the callback, if provided
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if not output_type == "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return CogVideoXPipelineOutput(frames=video)
\ No newline at end of file
diff --git a/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_output.py b/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_output.py
new file mode 100644
index 000000000..4bf9efaa3
--- /dev/null
+++ b/ppdiffusers/ppdiffusers/pipelines/cogvideo/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import paddle
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class CogVideoXPipelineOutput(BaseOutput):
+ r"""
+ Output class for CogVideo pipelines.
+
+ Args:
+ frames (`paddle.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
+ denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Paddle tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: paddle.Tensor
\ No newline at end of file
diff --git a/ppdiffusers/ppdiffusers/schedulers/__init__.py b/ppdiffusers/ppdiffusers/schedulers/__init__.py
index ae1ec2232..e914488e8 100644
--- a/ppdiffusers/ppdiffusers/schedulers/__init__.py
+++ b/ppdiffusers/ppdiffusers/schedulers/__init__.py
@@ -41,12 +41,14 @@
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
+ _import_structure["scheduling_ddim_cogvideox"] = ["CogVideoXDDIMScheduler"]
_import_structure["scheduling_ddim_inverse"] = ["DDIMInverseScheduler"]
_import_structure["scheduling_ddim_parallel"] = ["DDIMParallelScheduler"]
_import_structure["scheduling_ddpm"] = ["DDPMScheduler"]
_import_structure["scheduling_ddpm_parallel"] = ["DDPMParallelScheduler"]
_import_structure["scheduling_ddpm_wuerstchen"] = ["DDPMWuerstchenScheduler"]
_import_structure["scheduling_deis_multistep"] = ["DEISMultistepScheduler"]
+ _import_structure["scheduling_dpm_cogvideox"] = ["CogVideoXDPMScheduler"]
_import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"]
_import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"]
_import_structure["scheduling_dpmsolver_singlestep"] = ["DPMSolverSinglestepScheduler"]
@@ -118,12 +120,14 @@
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler
+ from .scheduling_ddim_cogvideox import CogVideoXDDIMScheduler
from .scheduling_ddim_inverse import DDIMInverseScheduler
from .scheduling_ddim_parallel import DDIMParallelScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_ddpm_parallel import DDPMParallelScheduler
from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler
from .scheduling_deis_multistep import DEISMultistepScheduler
+ from .scheduling_dpm_cogvideox import CogVideoXDPMScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .scheduling_dpmsolver_multistep_inverse import (
DPMSolverMultistepInverseScheduler,
diff --git a/ppdiffusers/ppdiffusers/schedulers/scheduling_ddim_cogvideox.py b/ppdiffusers/ppdiffusers/schedulers/scheduling_ddim_cogvideox.py
new file mode 100644
index 000000000..9cb7fcd85
--- /dev/null
+++ b/ppdiffusers/ppdiffusers/schedulers/scheduling_ddim_cogvideox.py
@@ -0,0 +1,341 @@
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import paddle
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+@dataclass
+class DDIMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+ prev_sample (`paddle.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`paddle.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+ prev_sample: paddle.Tensor
+ pred_original_sample: Optional[paddle.Tensor] = None
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999, alpha_transform_type='cosine'):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == 'cosine':
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+ elif alpha_transform_type == 'exp':
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+ else:
+ raise ValueError(f'Unsupported alpha_transform_type: {alpha_transform_type}')
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return paddle.to_tensor(data=betas, dtype='float32')
+
+
+def rescale_zero_terminal_snr(alphas_cumprod):
+ """
+ Rescales the cumulative alphas to have zero terminal SNR, based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).
+
+ Args:
+ alphas_cumprod (`paddle.Tensor`):
+ the cumulative product of alphas that the scheduler is being initialized with.
+
+ Returns:
+ `paddle.Tensor`: rescaled cumulative alphas with zero terminal SNR
+ """
+ # Shift the square roots so the last value is 0, then rescale so the first value is unchanged.
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+ alphas_bar = alphas_bar_sqrt ** 2
+ return alphas_bar
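+
+# Sanity sketch (illustrative): after this rescaling the last cumulative alpha is
+# exactly zero (zero terminal SNR) while the first one is unchanged, e.g.
+#
+#     alphas_bar = rescale_zero_terminal_snr(alphas_cumprod)
+#     # float(alphas_bar[-1]) == 0.0 and float(alphas_bar[0]) == float(alphas_cumprod[0])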
+
+
+class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
+ non-Markovian guidance.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ clip_sample (`bool`, defaults to `True`):
+ Clip the predicted sample for numerical stability.
+ clip_sample_range (`float`, defaults to 1.0):
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, defaults to `True`):
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the alpha value at step 0.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps, as required by some model families.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.00085,
+        beta_end: float = 0.012,
+        beta_schedule: str = 'scaled_linear',
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+        clip_sample: bool = True,
+        set_alpha_to_one: bool = True,
+        steps_offset: int = 0,
+        prediction_type: str = 'epsilon',
+        clip_sample_range: float = 1.0,
+        sample_max_value: float = 1.0,
+        timestep_spacing: str = 'leading',
+        rescale_betas_zero_snr: bool = False,
+        snr_shift_scale: float = 3.0,
+    ):
+        if trained_betas is not None:
+            self.betas = paddle.to_tensor(trained_betas, dtype='float32')
+        elif beta_schedule == 'linear':
+            self.betas = paddle.linspace(beta_start, beta_end, num_train_timesteps, dtype='float32')
+        elif beta_schedule == 'scaled_linear':
+            self.betas = paddle.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype='float64') ** 2
+        elif beta_schedule == 'squaredcos_cap_v2':
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
+        else:
+            raise NotImplementedError(f'{beta_schedule} is not implemented for {self.__class__}')
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = paddle.cumprod(x=self.alphas, dim=0)
+        # SNR shift following SD3
+        self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod)
+        # Rescale for zero terminal SNR
+        if rescale_betas_zero_snr:
+            self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod)
+        self.final_alpha_cumprod = paddle.to_tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = paddle.to_tensor(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+    def _get_variance(self, timestep, prev_timestep):
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        beta_prod_t = 1 - alpha_prod_t
+        beta_prod_t_prev = 1 - alpha_prod_t_prev
+        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+        return variance
+
+    def scale_model_input(self, sample: paddle.Tensor, timestep: Optional[int] = None) -> paddle.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+            sample (`paddle.Tensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+            `paddle.Tensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ def set_timesteps(self, num_inference_steps: int):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f'`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`: {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle maximal {self.config.num_train_timesteps} timesteps.'
+ )
+ self.num_inference_steps = num_inference_steps
+        if self.config.timestep_spacing == 'linspace':
+            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round()[::-1].copy().astype(np.int64)
+        elif self.config.timestep_spacing == 'leading':
+            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+            timesteps += self.config.steps_offset
+        elif self.config.timestep_spacing == 'trailing':
+            step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+            timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+            timesteps -= 1
+        else:
+            raise ValueError(
+                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+            )
+        self.timesteps = paddle.to_tensor(timesteps)
+
+    def step(self, model_output: paddle.Tensor, timestep: int, sample: paddle.Tensor, eta: float = 0.0,
+             use_clipped_model_output: bool = False, generator=None,
+             variance_noise: Optional[paddle.Tensor] = None, return_dict: bool = True) -> Union[DDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+            model_output (`paddle.Tensor`):
+ The direct output from learned diffusion model.
+ timestep (`float`):
+ The current discrete timestep in the diffusion chain.
+            sample (`paddle.Tensor`):
+ A current instance of a sample created by the diffusion process.
+ eta (`float`):
+ The weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`, defaults to `False`):
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
+ `use_clipped_model_output` has no effect.
+            generator (`paddle.Generator`, *optional*):
+ A random number generator.
+            variance_noise (`paddle.Tensor`):
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
+ itself. Useful for methods such as [`CycleDiffusion`].
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
+
+ Returns:
+ [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        beta_prod_t = 1 - alpha_prod_t
+        if self.config.prediction_type == 'epsilon':
+            pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
+        elif self.config.prediction_type == 'sample':
+            pred_original_sample = model_output
+        elif self.config.prediction_type == 'v_prediction':
+            pred_original_sample = alpha_prod_t ** 0.5 * sample - beta_prod_t ** 0.5 * model_output
+        else:
+            raise ValueError(
+                f'prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`'
+            )
+        # Deterministic DDIM update (eta = 0), written as a combination of the current sample and the
+        # predicted x_0: x_{t-1} = a_t * x_t + b_t * x_0.
+        a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5
+        b_t = alpha_prod_t_prev ** 0.5 - alpha_prod_t ** 0.5 * a_t
+        prev_sample = a_t * sample + b_t * pred_original_sample
+        if not return_dict:
+            return (prev_sample,)
+        return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+    def add_noise(self, original_samples: paddle.Tensor, noise: paddle.Tensor, timesteps: paddle.Tensor) -> paddle.Tensor:
+        alphas_cumprod = self.alphas_cumprod.cast(original_samples.dtype)
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(axis=-1)
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(axis=-1)
+        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+        return noisy_samples
+
+    def get_velocity(self, sample: paddle.Tensor, noise: paddle.Tensor, timesteps: paddle.Tensor) -> paddle.Tensor:
+        alphas_cumprod = self.alphas_cumprod.cast(sample.dtype)
+        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+        while len(sqrt_alpha_prod.shape) < len(sample.shape):
+            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(axis=-1)
+        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(axis=-1)
+        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+        return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
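Both CogVideoX schedulers register their arguments through `ConfigMixin`, so they can be rebuilt from a pipeline's existing scheduler config. A minimal sketch of swapping the DDIM variant into the text-to-video pipeline follows; the model id, prompt, and the `.frames` output field are assumptions based on the usual CogVideoX pipeline layout, not something this file defines:

```python
from ppdiffusers import CogVideoXDDIMScheduler, CogVideoXPipeline

# Hypothetical usage: rebuild the scheduler from the pipeline's own config,
# overriding only the timestep spacing (any registered __init__ argument works).
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b")
pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

# The pipeline is then called as usual; `.frames` is assumed from the CogVideoX pipeline output class.
video_frames = pipe(prompt="A panda playing a guitar", num_inference_steps=50).frames[0]
```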
diff --git a/ppdiffusers/ppdiffusers/schedulers/scheduling_dpm_cogvideox.py b/ppdiffusers/ppdiffusers/schedulers/scheduling_dpm_cogvideox.py
new file mode 100644
index 000000000..9e679db98
--- /dev/null
+++ b/ppdiffusers/ppdiffusers/schedulers/scheduling_dpm_cogvideox.py
@@ -0,0 +1,453 @@
+import paddle
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+import numpy as np
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
+from ..utils.paddle_utils import randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+@dataclass
+class DDIMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+        prev_sample (`paddle.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+        pred_original_sample (`paddle.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+ prev_sample: paddle.Tensor
+ pred_original_sample: Optional[paddle.Tensor] = None
+
+
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type='cosine'
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == 'cosine':
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == 'exp':
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(
+ f'Unsupported alpha_transform_type: {alpha_transform_type}')
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return paddle.to_tensor(betas, dtype='float32')
+
+
+def rescale_zero_terminal_snr(alphas_cumprod):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+ Args:
+        alphas_cumprod (`paddle.Tensor`):
+            the cumulative product of alphas the scheduler is being initialized with.
+
+    Returns:
+        `paddle.Tensor`: rescaled alphas_cumprod with zero terminal SNR
+ """
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+
+ return alphas_bar
+
+
+class CogVideoXDPMScheduler(SchedulerMixin, ConfigMixin):
+ """
+    `CogVideoXDPMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
+ non-Markovian guidance.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
+ methods the library implements for all schedulers such as loading and saving.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ clip_sample (`bool`, defaults to `True`):
+ Clip the predicted sample for numerical stability.
+ clip_sample_range (`float`, defaults to 1.0):
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, defaults to `True`):
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the alpha value at step 0.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps, as required by some model families.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int=1000,
+ beta_start: float=0.00085,
+ beta_end: float=0.012,
+ beta_schedule: str='scaled_linear',
+ trained_betas: Optional[Union[np.ndarray, List[float]]]=None,
+ clip_sample: bool=True,
+ set_alpha_to_one: bool=True,
+ steps_offset: int=0,
+ prediction_type: str='epsilon',
+ clip_sample_range: float=1.0,
+ sample_max_value: float=1.0,
+ timestep_spacing: str='leading',
+ rescale_betas_zero_snr: bool=False,
+ snr_shift_scale: float=3.0
+ ):
+ if trained_betas is not None:
+ self.betas = paddle.to_tensor(trained_betas, dtype='float32')
+ elif beta_schedule == 'linear':
+ self.betas = paddle.linspace(beta_start, beta_end, num_train_timesteps, dtype='float32')
+ elif beta_schedule == 'scaled_linear':
+ self.betas = paddle.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps, dtype='float64') ** 2
+ elif beta_schedule == 'squaredcos_cap_v2':
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(
+ f'{beta_schedule} is not implemented for {self.__class__}')
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = paddle.cumprod(x=self.alphas, dim=0)
+
+ # Modify: SNR shift following SD3
+ self.alphas_cumprod = self.alphas_cumprod / (snr_shift_scale + (1 - snr_shift_scale) * self.alphas_cumprod)
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = paddle.to_tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = paddle.to_tensor(np.arange(0,num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ def _get_variance(
+ self,
+ timestep,
+ prev_timestep
+ ):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ def scale_model_input(
+ self,
+ sample: paddle.Tensor,
+ timestep: Optional[int]=None
+ ) ->paddle.Tensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+            sample (`paddle.Tensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+
+ Returns:
+            `paddle.Tensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ def set_timesteps(self, num_inference_steps):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ """
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f'`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`: {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle maximal {self.config.num_train_timesteps} timesteps.'
+ )
+ self.num_inference_steps = num_inference_steps
+
+ if self.config.timestep_spacing == 'linspace':
+ timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps).round()[::-1].copy().astype(np.int64)
+
+ elif self.config.timestep_spacing == 'leading':
+            step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ timesteps += self.config.steps_offset
+
+ elif self.config.timestep_spacing == 'trailing':
+ step_ratio = (self.config.num_train_timesteps / self.num_inference_steps)
+ timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
+ timesteps -= 1
+ else:
+            raise ValueError(
+                f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+            )
+ self.timesteps = paddle.to_tensor(timesteps)
+
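+    # The helpers below follow a DPM-Solver style formulation: `lamb` is the half log-SNR
+    # log(sqrt(alpha_bar_t / (1 - alpha_bar_t))), `h` is the step size between adjacent timesteps in
+    # lambda space, and `r` is the ratio of the previous step size to the current one, used by the
+    # second-order (multistep) update in `step`.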
+ def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None):
+ lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log()
+ lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log()
+ h = lamb_next - lamb
+
+ if alpha_prod_t_back is not None:
+ lamb_previous = ((alpha_prod_t_back / (1 - alpha_prod_t_back)) ** 0.5).log()
+ h_last = lamb - lamb_previous
+ r = h_last / h
+ return h, r, lamb, lamb_next
+ else:
+ return h, None, lamb, lamb_next
+
+ def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
+ mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp()
+ mult2 = (-2 * h).expm1() * alpha_prod_t_prev ** 0.5
+ if alpha_prod_t_back is not None:
+ mult3 = 1 + 1 / (2 * r)
+ mult4 = 1 / (2 * r)
+ return mult1, mult2, mult3, mult4
+ else:
+ return mult1, mult2
+
+    def step(
+        self,
+        model_output: paddle.Tensor,
+        old_pred_original_sample: paddle.Tensor,
+        timestep: int,
+        timestep_back: int,
+        sample: paddle.Tensor,
+        eta: float = 0.0,
+        use_clipped_model_output: bool = False,
+        generator=None,
+        variance_noise: Optional[paddle.Tensor] = None,
+        return_dict: bool = False,
+    ) -> Union[DDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+            model_output (`paddle.Tensor`):
+                The direct output from learned diffusion model.
+            old_pred_original_sample (`paddle.Tensor`):
+                The predicted original sample (`x_0`) from the previous step, or `None` on the first step. Used for
+                the multistep correction.
+            timestep (`float`):
+                The current discrete timestep in the diffusion chain.
+            timestep_back (`int`):
+                The timestep of the preceding sampling step, or `None` on the first step. Used to compute the
+                step-size ratio for the multistep correction.
+            sample (`paddle.Tensor`):
+                A current instance of a sample created by the diffusion process.
+ eta (`float`):
+ The weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`, defaults to `False`):
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
+ `use_clipped_model_output` has no effect.
+            generator (`paddle.Generator`, *optional*):
+ A random number generator.
+            variance_noise (`paddle.Tensor`):
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
+ itself. Useful for methods such as [`CycleDiffusion`].
+            return_dict (`bool`, *optional*, defaults to `False`):
+ Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
+
+ Returns:
+ [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`:
+ If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
+ tuple is returned where the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read DDIM paper in-detail understanding
+
+        # Notation (<variable name> -> <name in paper>)
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ alpha_prod_t_back = self.alphas_cumprod[timestep_back] if timestep_back is not None else None
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # To make style tests pass, commented out `pred_epsilon` as it is an unused variable
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ # pred_epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ # pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ # pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction`"
+ )
+
+ h, r, lamb, lamb_next = self.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)
+ mult = list(self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back))
+ mult_noise = (1 - alpha_prod_t_prev) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5
+
+ noise = randn_tensor(sample.shape, generator=generator, dtype=sample.dtype)
+ prev_sample = mult[0] * sample - mult[1] * pred_original_sample + mult_noise * noise
+
+ if old_pred_original_sample is None or prev_timestep < 0:
+ # Save a network evaluation if all noise levels are 0 or on the first step
+ return prev_sample, pred_original_sample
+ else:
+ denoised_d = mult[2] * pred_original_sample - mult[3] * old_pred_original_sample
+ noise = randn_tensor(sample.shape, generator=generator, dtype=sample.dtype)
+ x_advanced = mult[0] * sample - mult[1] * denoised_d + mult_noise * noise
+
+ prev_sample = x_advanced
+
+ if not return_dict:
+ return (prev_sample, pred_original_sample)
+
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+
+ def add_noise(
+ self,
+ original_samples: paddle.Tensor,
+ noise: paddle.Tensor,
+ timesteps: paddle.Tensor
+ ) ->paddle.Tensor:
+
+ alphas_cumprod = self.alphas_cumprod.cast(original_samples.dtype)
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(axis=-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(axis=-1)
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+
+ return noisy_samples
+
+ def get_velocity(
+ self,
+ sample: paddle.Tensor,
+ noise: paddle.Tensor,
+ timesteps: paddle.Tensor
+ ) ->paddle.Tensor:
+ alphas_cumprod = self.alphas_cumprod.cast(sample.dtype)
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(axis=-1)
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(axis=-1)
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
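Unlike the DDIM variant, `step` here also consumes the previous step's `pred_original_sample` and timestep and returns both values, so the caller has to thread them through the sampling loop. A minimal sketch of that calling pattern, using a dummy model output purely to exercise the scheduler API (the latent shape and `dummy_model` are placeholders, not part of this patch):

```python
import paddle

from ppdiffusers import CogVideoXDPMScheduler

scheduler = CogVideoXDPMScheduler()
scheduler.set_timesteps(50)

# Placeholder latent and "model"; a real pipeline would call its transformer here.
latents = paddle.randn([1, 4, 2, 8, 8])


def dummy_model(x, t):
    return paddle.randn(x.shape)


old_pred_original_sample = None
for i, t in enumerate(scheduler.timesteps):
    model_output = dummy_model(latents, t)
    timestep_back = scheduler.timesteps[i - 1] if i > 0 else None
    latents, old_pred_original_sample = scheduler.step(
        model_output, old_pred_original_sample, t, timestep_back, latents, return_dict=False
    )
```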
diff --git a/ppdiffusers/ppdiffusers/utils/export_utils.py b/ppdiffusers/ppdiffusers/utils/export_utils.py
index 130ec3c05..df94ec2e2 100644
--- a/ppdiffusers/ppdiffusers/utils/export_utils.py
+++ b/ppdiffusers/ppdiffusers/utils/export_utils.py
@@ -161,15 +161,35 @@ def export_to_video(
video_writer.write(img)
return output_video_path
-
def export_to_video_2(
- video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 8
-):
+ video_frames: Union[List[np.ndarray], List[PIL.Image.Image]], output_video_path: str = None, fps: int = 10
+) -> str:
try:
import imageio
- except ImportError:
- raise ImportError("Please install imageio to export video.run `pip install imageio`")
+    except ImportError:
+        raise ImportError("Please install imageio to export video: run `pip install imageio`")
+
+ try:
+ imageio.plugins.ffmpeg.get_exe()
+ except AttributeError:
+        raise AttributeError(
+            "Found an existing imageio backend in your environment, but could not find a compatible ffmpeg "
+            "installation to use with it. Please install it via `pip install imageio-ffmpeg`."
+        )
+
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name
- imageio.mimsave(output_video_path, video_frames, fps=fps, codec="mpeg4")
+ if isinstance(video_frames[0], np.ndarray):
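+        # Float frames are assumed to be in [0, 1] and are rescaled to uint8 for imageio.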
+ video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
+
+ elif isinstance(video_frames[0], PIL.Image.Image):
+ video_frames = [np.array(frame) for frame in video_frames]
+
+ with imageio.get_writer(output_video_path, fps=fps) as writer:
+ for frame in video_frames:
+ writer.append_data(frame)
+
+ return output_video_path
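For reference, a small usage sketch of the rewritten helper; it assumes float frames lie in `[0, 1]`, which is what the scaling branch above expects (the file name and frame data are placeholders):

```python
import numpy as np

from ppdiffusers.utils import export_to_video_2

# 16 random RGB frames at 64x64, float32 in [0, 1] (placeholder data).
frames = [np.random.rand(64, 64, 3).astype(np.float32) for _ in range(16)]
print(export_to_video_2(frames, "sample.mp4", fps=8))
```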
diff --git a/ppdiffusers/ppdiffusers/video_processor.py b/ppdiffusers/ppdiffusers/video_processor.py
new file mode 100644
index 000000000..f8c9e5d2a
--- /dev/null
+++ b/ppdiffusers/ppdiffusers/video_processor.py
@@ -0,0 +1,89 @@
+import paddle
+import warnings
+from typing import List, Optional, Union
+import numpy as np
+import PIL
+from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
+
+
+class VideoProcessor(VaeImageProcessor):
+ """Simple video processor."""
+
+    def preprocess_video(self, video, height: Optional[int] = None, width: Optional[int] = None) -> paddle.Tensor:
+ """
+ Preprocesses input video(s).
+
+ Args:
+            video (`List[PIL.Image]`, `List[List[PIL.Image]]`, `paddle.Tensor`, `np.array`, `List[paddle.Tensor]`, `List[np.array]`):
+                The input video. It can be one of the following:
+                * List of the PIL images.
+                * List of list of PIL images.
+                * 4D Paddle tensors (expected shape for each tensor `(num_frames, num_channels, height, width)`).
+                * 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`).
+                * List of 4D Paddle tensors (expected shape for each tensor `(num_frames, num_channels, height,
+                  width)`).
+                * List of 4D NumPy arrays (expected shape for each array `(num_frames, height, width, num_channels)`).
+                * 5D NumPy arrays: expected shape for each array `(batch_size, num_frames, height, width,
+                  num_channels)`.
+                * 5D Paddle tensors: expected shape for each array `(batch_size, num_frames, num_channels, height,
+                  width)`.
+            height (`int`, *optional*, defaults to `None`):
+                The height in preprocessed frames of the video. If `None`, will use `get_default_height_width()` to
+                get the default height.
+            width (`int`, *optional*, defaults to `None`):
+                The width in preprocessed frames of the video. If `None`, will use `get_default_height_width()` to
+                get the default width.
+ """
+        if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5:
+            warnings.warn(
+                "Passing `video` as a list of 5d np.ndarray is deprecated. "
+                "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray",
+                FutureWarning,
+            )
+            video = np.concatenate(video, axis=0)
+        if isinstance(video, list) and isinstance(video[0], paddle.Tensor) and video[0].ndim == 5:
+            warnings.warn(
+                "Passing `video` as a list of 5d paddle.Tensor is deprecated. "
+                "Please concatenate the list along the batch dimension and pass it as a single 5d paddle.Tensor",
+                FutureWarning,
+            )
+            video = paddle.concat(video, axis=0)
+        if isinstance(video, (np.ndarray, paddle.Tensor)) and video.ndim == 5:
+            video = list(video)
+        elif (isinstance(video, list) and is_valid_image(video[0])) or is_valid_image_imagelist(video):
+            video = [video]
+        elif isinstance(video, list) and is_valid_image_imagelist(video[0]):
+            video = video
+        else:
+            raise ValueError(
+                "Input is in incorrect format. Currently, we only support numpy.ndarray, paddle.Tensor, PIL.Image.Image"
+            )
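+        # Stack the per-clip results to (batch, frames, channels, height, width), then move the
+        # channel axis in front of the frame axis: (batch, channels, frames, height, width).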
+        video = paddle.stack(x=[self.preprocess(img, height=height, width=width) for img in video], axis=0)
+        video = video.transpose(perm=[0, 2, 1, 3, 4])
+ return video
+
+    def postprocess_video(
+        self, video: paddle.Tensor, output_type: str = 'np'
+    ) -> Union[np.ndarray, paddle.Tensor, List[PIL.Image.Image]]:
+ """
+ Converts a video tensor to a list of frames for export.
+
+ Args:
+            video (`paddle.Tensor`): The video as a tensor.
+ output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor.
+ """
+ batch_size = tuple(video.shape)[0]
+ outputs = []
+ for batch_idx in range(batch_size):
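+            # (channels, frames, height, width) -> (frames, channels, height, width) so that each
+            # frame can be postprocessed like a regular image batch.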
+ batch_vid = video[batch_idx].transpose(perm=[1, 0, 2, 3])
+ batch_output = self.postprocess(batch_vid, output_type)
+ outputs.append(batch_output)
+ if output_type == 'np':
+ outputs = np.stack(outputs)
+ elif output_type == 'pd':
+ outputs = paddle.stack(x=outputs)
+        elif output_type != 'pil':
+ raise ValueError(
+ f"{output_type} does not exist. Please choose one of ['np', 'pd', 'pil']"
+ )
+ return outputs
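A small round-trip sketch of the processor, assuming it is constructed like its `VaeImageProcessor` base class (the `vae_scale_factor` argument and the shape comments are illustrative, not defined in this patch):

```python
import numpy as np

from ppdiffusers.video_processor import VideoProcessor

processor = VideoProcessor(vae_scale_factor=8)  # assumed constructor argument inherited from VaeImageProcessor

# One clip of 8 RGB frames, float32 in [0, 1], shape (num_frames, height, width, channels).
clip = np.random.rand(8, 256, 256, 3).astype(np.float32)

video = processor.preprocess_video(clip, height=256, width=256)  # roughly (1, 3, 8, 256, 256), values in [-1, 1]
frames = processor.postprocess_video(video, output_type="np")    # back to roughly (1, 8, 256, 256, 3)
```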