diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index a50a70aec5c..9e6dfc17708 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -382,3 +382,7 @@ class HunyuanVideo(LatentFormat): ] latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] + +class Cosmos1CV8x8x8(LatentFormat): + latent_channels = 16 + latent_dimensions = 3 diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py new file mode 100644 index 00000000000..3e9c6497a03 --- /dev/null +++ b/comfy/ldm/cosmos/blocks.py @@ -0,0 +1,804 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional +import logging + +import numpy as np +import torch +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from torch import nn + +from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm +from comfy.ldm.modules.attention import optimized_attention + + +def apply_rotary_pos_emb( + t: torch.Tensor, + freqs: torch.Tensor, +) -> torch.Tensor: + t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() + t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] + t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) + return t_out + + +def get_normalization(name: str, channels: int, weight_args={}): + if name == "I": + return nn.Identity() + elif name == "R": + return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args) + else: + raise ValueError(f"Normalization {name} not found") + + +class BaseAttentionOp(nn.Module): + def __init__(self): + super().__init__() + + +class Attention(nn.Module): + """ + Generalized attention impl. + + Allowing for both self-attention and cross-attention configurations depending on whether a `context_dim` is provided. + If `context_dim` is None, self-attention is assumed. + + Parameters: + query_dim (int): Dimension of each query vector. + context_dim (int, optional): Dimension of each context vector. If None, self-attention is assumed. + heads (int, optional): Number of attention heads. Defaults to 8. + dim_head (int, optional): Dimension of each head. Defaults to 64. + dropout (float, optional): Dropout rate applied to the output of the attention block. Defaults to 0.0. + attn_op (BaseAttentionOp, optional): Custom attention operation to be used instead of the default. + qkv_bias (bool, optional): If True, adds a learnable bias to query, key, and value projections. Defaults to False. + out_bias (bool, optional): If True, adds a learnable bias to the output projection. Defaults to False. + qkv_norm (str, optional): A string representing normalization strategies for query, key, and value projections. + Defaults to "SSI". + qkv_norm_mode (str, optional): A string representing normalization mode for query, key, and value projections. + Defaults to 'per_head'. Only support 'per_head'. + + Examples: + >>> attn = Attention(query_dim=128, context_dim=256, heads=4, dim_head=32, dropout=0.1) + >>> query = torch.randn(10, 128) # Batch size of 10 + >>> context = torch.randn(10, 256) # Batch size of 10 + >>> output = attn(query, context) # Perform the attention operation + + Note: + https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + """ + + def __init__( + self, + query_dim: int, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + attn_op: Optional[BaseAttentionOp] = None, + qkv_bias: bool = False, + out_bias: bool = False, + qkv_norm: str = "SSI", + qkv_norm_mode: str = "per_head", + backend: str = "transformer_engine", + qkv_format: str = "bshd", + weight_args={}, + operations=None, + ) -> None: + super().__init__() + + self.is_selfattn = context_dim is None # self attention + + inner_dim = dim_head * heads + context_dim = query_dim if context_dim is None else context_dim + + self.heads = heads + self.dim_head = dim_head + self.qkv_norm_mode = qkv_norm_mode + self.qkv_format = qkv_format + + if self.qkv_norm_mode == "per_head": + norm_dim = dim_head + else: + raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'") + + self.backend = backend + + self.to_q = nn.Sequential( + operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args), + get_normalization(qkv_norm[0], norm_dim), + ) + self.to_k = nn.Sequential( + operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args), + get_normalization(qkv_norm[1], norm_dim), + ) + self.to_v = nn.Sequential( + operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args), + get_normalization(qkv_norm[2], norm_dim), + ) + + self.to_out = nn.Sequential( + operations.Linear(inner_dim, query_dim, bias=out_bias, **weight_args), + nn.Dropout(dropout), + ) + + def cal_qkv( + self, x, context=None, mask=None, rope_emb=None, **kwargs + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + del kwargs + + + """ + self.to_q, self.to_k, self.to_v are nn.Sequential with projection + normalization layers. + Before 07/24/2024, these modules normalize across all heads. + After 07/24/2024, to support tensor parallelism and follow the common practice in the community, + we support to normalize per head. + To keep the checkpoint copatibility with the previous code, + we keep the nn.Sequential but call the projection and the normalization layers separately. + We use a flag `self.qkv_norm_mode` to control the normalization behavior. + The default value of `self.qkv_norm_mode` is "per_head", which means we normalize per head. + """ + if self.qkv_norm_mode == "per_head": + q = self.to_q[0](x) + context = x if context is None else context + k = self.to_k[0](context) + v = self.to_v[0](context) + q, k, v = map( + lambda t: rearrange(t, "s b (n c) -> b n s c", n=self.heads, c=self.dim_head), + (q, k, v), + ) + else: + raise ValueError(f"Normalization mode {self.qkv_norm_mode} not found, only support 'per_head'") + + q = self.to_q[1](q) + k = self.to_k[1](k) + v = self.to_v[1](v) + if self.is_selfattn and rope_emb is not None: # only apply to self-attention! + q = apply_rotary_pos_emb(q, rope_emb) + k = apply_rotary_pos_emb(k, rope_emb) + return q, k, v + + def cal_attn(self, q, k, v, mask=None): + out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True) + out = rearrange(out, " b n s c -> s b (n c)") + return self.to_out(out) + + def forward( + self, + x, + context=None, + mask=None, + rope_emb=None, + **kwargs, + ): + """ + Args: + x (Tensor): The query tensor of shape [B, Mq, K] + context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None + """ + q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs) + return self.cal_attn(q, k, v, mask) + + +class FeedForward(nn.Module): + """ + Transformer FFN with optional gating + + Parameters: + d_model (int): Dimensionality of input features. + d_ff (int): Dimensionality of the hidden layer. + dropout (float, optional): Dropout rate applied after the activation function. Defaults to 0.1. + activation (callable, optional): The activation function applied after the first linear layer. + Defaults to nn.ReLU(). + is_gated (bool, optional): If set to True, incorporates gating mechanism to the feed-forward layer. + Defaults to False. + bias (bool, optional): If set to True, adds a bias to the linear layers. Defaults to True. + + Example: + >>> ff = FeedForward(d_model=512, d_ff=2048) + >>> x = torch.randn(64, 10, 512) # Example input tensor + >>> output = ff(x) + >>> print(output.shape) # Expected shape: (64, 10, 512) + """ + + def __init__( + self, + d_model: int, + d_ff: int, + dropout: float = 0.1, + activation=nn.ReLU(), + is_gated: bool = False, + bias: bool = False, + weight_args={}, + operations=None, + ) -> None: + super().__init__() + + self.layer1 = operations.Linear(d_model, d_ff, bias=bias, **weight_args) + self.layer2 = operations.Linear(d_ff, d_model, bias=bias, **weight_args) + + self.dropout = nn.Dropout(dropout) + self.activation = activation + self.is_gated = is_gated + if is_gated: + self.linear_gate = operations.Linear(d_model, d_ff, bias=False, **weight_args) + + def forward(self, x: torch.Tensor): + g = self.activation(self.layer1(x)) + if self.is_gated: + x = g * self.linear_gate(x) + else: + x = g + assert self.dropout.p == 0.0, "we skip dropout" + return self.layer2(x) + + +class GPT2FeedForward(FeedForward): + def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1, bias: bool = False, weight_args={}, operations=None): + super().__init__( + d_model=d_model, + d_ff=d_ff, + dropout=dropout, + activation=nn.GELU(), + is_gated=False, + bias=bias, + weight_args=weight_args, + operations=operations, + ) + + def forward(self, x: torch.Tensor): + assert self.dropout.p == 0.0, "we skip dropout" + + x = self.layer1(x) + x = self.activation(x) + x = self.layer2(x) + + return x + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class Timesteps(nn.Module): + def __init__(self, num_channels): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps): + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - 0.0) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + emb = torch.cat([cos_emb, sin_emb], dim=-1) + + return emb + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_features: int, out_features: int, use_adaln_lora: bool = False, weight_args={}, operations=None): + super().__init__() + logging.debug( + f"Using AdaLN LoRA Flag: {use_adaln_lora}. We enable bias if no AdaLN LoRA for backward compatibility." + ) + self.linear_1 = operations.Linear(in_features, out_features, bias=not use_adaln_lora, **weight_args) + self.activation = nn.SiLU() + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.linear_2 = operations.Linear(out_features, 3 * out_features, bias=False, **weight_args) + else: + self.linear_2 = operations.Linear(out_features, out_features, bias=True, **weight_args) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + emb = self.linear_1(sample) + emb = self.activation(emb) + emb = self.linear_2(emb) + + if self.use_adaln_lora: + adaln_lora_B_3D = emb + emb_B_D = sample + else: + emb_B_D = emb + adaln_lora_B_3D = None + + return emb_B_D, adaln_lora_B_3D + + +class FourierFeatures(nn.Module): + """ + Implements a layer that generates Fourier features from input tensors, based on randomly sampled + frequencies and phases. This can help in learning high-frequency functions in low-dimensional problems. + + [B] -> [B, D] + + Parameters: + num_channels (int): The number of Fourier features to generate. + bandwidth (float, optional): The scaling factor for the frequency of the Fourier features. Defaults to 1. + normalize (bool, optional): If set to True, the outputs are scaled by sqrt(2), usually to normalize + the variance of the features. Defaults to False. + + Example: + >>> layer = FourierFeatures(num_channels=256, bandwidth=0.5, normalize=True) + >>> x = torch.randn(10, 256) # Example input tensor + >>> output = layer(x) + >>> print(output.shape) # Expected shape: (10, 256) + """ + + def __init__(self, num_channels, bandwidth=1, normalize=False): + super().__init__() + self.register_buffer("freqs", 2 * np.pi * bandwidth * torch.randn(num_channels), persistent=True) + self.register_buffer("phases", 2 * np.pi * torch.rand(num_channels), persistent=True) + self.gain = np.sqrt(2) if normalize else 1 + + def forward(self, x, gain: float = 1.0): + """ + Apply the Fourier feature transformation to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + gain (float, optional): An additional gain factor applied during the forward pass. Defaults to 1. + + Returns: + torch.Tensor: The transformed tensor, with Fourier features applied. + """ + in_dtype = x.dtype + x = x.to(torch.float32).ger(self.freqs.to(torch.float32)).add(self.phases.to(torch.float32)) + x = x.cos().mul(self.gain * gain).to(in_dtype) + return x + + +class PatchEmbed(nn.Module): + """ + PatchEmbed is a module for embedding patches from an input tensor by applying either 3D or 2D convolutional layers, + depending on the . This module can process inputs with temporal (video) and spatial (image) dimensions, + making it suitable for video and image processing tasks. It supports dividing the input into patches + and embedding each patch into a vector of size `out_channels`. + + Parameters: + - spatial_patch_size (int): The size of each spatial patch. + - temporal_patch_size (int): The size of each temporal patch. + - in_channels (int): Number of input channels. Default: 3. + - out_channels (int): The dimension of the embedding vector for each patch. Default: 768. + - bias (bool): If True, adds a learnable bias to the output of the convolutional layers. Default: True. + """ + + def __init__( + self, + spatial_patch_size, + temporal_patch_size, + in_channels=3, + out_channels=768, + bias=True, + weight_args={}, + operations=None, + ): + super().__init__() + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + + self.proj = nn.Sequential( + Rearrange( + "b c (t r) (h m) (w n) -> b t h w (c r m n)", + r=temporal_patch_size, + m=spatial_patch_size, + n=spatial_patch_size, + ), + operations.Linear( + in_channels * spatial_patch_size * spatial_patch_size * temporal_patch_size, out_channels, bias=bias, **weight_args + ), + ) + self.out = nn.Identity() + + def forward(self, x): + """ + Forward pass of the PatchEmbed module. + + Parameters: + - x (torch.Tensor): The input tensor of shape (B, C, T, H, W) where + B is the batch size, + C is the number of channels, + T is the temporal dimension, + H is the height, and + W is the width of the input. + + Returns: + - torch.Tensor: The embedded patches as a tensor, with shape b t h w c. + """ + assert x.dim() == 5 + _, _, T, H, W = x.shape + assert H % self.spatial_patch_size == 0 and W % self.spatial_patch_size == 0 + assert T % self.temporal_patch_size == 0 + x = self.proj(x) + return self.out(x) + + +class FinalLayer(nn.Module): + """ + The final layer of video DiT. + """ + + def __init__( + self, + hidden_size, + spatial_patch_size, + temporal_patch_size, + out_channels, + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + weight_args={}, + operations=None, + ): + super().__init__() + self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **weight_args) + self.linear = operations.Linear( + hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, **weight_args + ) + self.hidden_size = hidden_size + self.n_adaln_chunks = 2 + self.use_adaln_lora = use_adaln_lora + if use_adaln_lora: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(hidden_size, adaln_lora_dim, bias=False, **weight_args), + operations.Linear(adaln_lora_dim, self.n_adaln_chunks * hidden_size, bias=False, **weight_args), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), operations.Linear(hidden_size, self.n_adaln_chunks * hidden_size, bias=False, **weight_args) + ) + + def forward( + self, + x_BT_HW_D, + emb_B_D, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ): + if self.use_adaln_lora: + assert adaln_lora_B_3D is not None + shift_B_D, scale_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D[:, : 2 * self.hidden_size]).chunk( + 2, dim=1 + ) + else: + shift_B_D, scale_B_D = self.adaLN_modulation(emb_B_D).chunk(2, dim=1) + + B = emb_B_D.shape[0] + T = x_BT_HW_D.shape[0] // B + shift_BT_D, scale_BT_D = repeat(shift_B_D, "b d -> (b t) d", t=T), repeat(scale_B_D, "b d -> (b t) d", t=T) + x_BT_HW_D = modulate(self.norm_final(x_BT_HW_D), shift_BT_D, scale_BT_D) + + x_BT_HW_D = self.linear(x_BT_HW_D) + return x_BT_HW_D + + +class VideoAttn(nn.Module): + """ + Implements video attention with optional cross-attention capabilities. + + This module processes video features while maintaining their spatio-temporal structure. It can perform + self-attention within the video features or cross-attention with external context features. + + Parameters: + x_dim (int): Dimension of input feature vectors + context_dim (Optional[int]): Dimension of context features for cross-attention. None for self-attention + num_heads (int): Number of attention heads + bias (bool): Whether to include bias in attention projections. Default: False + qkv_norm_mode (str): Normalization mode for query/key/value projections. Must be "per_head". Default: "per_head" + x_format (str): Format of input tensor. Must be "BTHWD". Default: "BTHWD" + + Input shape: + - x: (T, H, W, B, D) video features + - context (optional): (M, B, D) context features for cross-attention + where: + T: temporal dimension + H: height + W: width + B: batch size + D: feature dimension + M: context sequence length + """ + + def __init__( + self, + x_dim: int, + context_dim: Optional[int], + num_heads: int, + bias: bool = False, + qkv_norm_mode: str = "per_head", + x_format: str = "BTHWD", + weight_args={}, + operations=None, + ) -> None: + super().__init__() + self.x_format = x_format + + self.attn = Attention( + x_dim, + context_dim, + num_heads, + x_dim // num_heads, + qkv_bias=bias, + qkv_norm="RRI", + out_bias=bias, + qkv_norm_mode=qkv_norm_mode, + qkv_format="sbhd", + weight_args=weight_args, + operations=operations, + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass for video attention. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D) representing batches of video data. + context (Tensor): Context tensor of shape (B, M, D) or (M, B, D), + where M is the sequence length of the context. + crossattn_mask (Optional[Tensor]): An optional mask for cross-attention mechanisms. + rope_emb_L_1_1_D (Optional[Tensor]): + Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. + + Returns: + Tensor: The output tensor with applied attention, maintaining the input shape. + """ + + x_T_H_W_B_D = x + context_M_B_D = context + T, H, W, B, D = x_T_H_W_B_D.shape + x_THW_B_D = rearrange(x_T_H_W_B_D, "t h w b d -> (t h w) b d") + x_THW_B_D = self.attn( + x_THW_B_D, + context_M_B_D, + crossattn_mask, + rope_emb=rope_emb_L_1_1_D, + ) + x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W) + return x_T_H_W_B_D + + +def adaln_norm_state(norm_state, x, scale, shift): + normalized = norm_state(x) + return normalized * (1 + scale) + shift + + +class DITBuildingBlock(nn.Module): + """ + A building block for the DiT (Diffusion Transformer) architecture that supports different types of + attention and MLP operations with adaptive layer normalization. + + Parameters: + block_type (str): Type of block - one of: + - "cross_attn"/"ca": Cross-attention + - "full_attn"/"fa": Full self-attention + - "mlp"/"ff": MLP/feedforward block + x_dim (int): Dimension of input features + context_dim (Optional[int]): Dimension of context features for cross-attention + num_heads (int): Number of attention heads + mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0 + bias (bool): Whether to use bias in layers. Default: False + mlp_dropout (float): Dropout rate for MLP. Default: 0.0 + qkv_norm_mode (str): QKV normalization mode. Default: "per_head" + x_format (str): Input tensor format. Default: "BTHWD" + use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False + adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256 + """ + + def __init__( + self, + block_type: str, + x_dim: int, + context_dim: Optional[int], + num_heads: int, + mlp_ratio: float = 4.0, + bias: bool = False, + mlp_dropout: float = 0.0, + qkv_norm_mode: str = "per_head", + x_format: str = "BTHWD", + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + weight_args={}, + operations=None + ) -> None: + block_type = block_type.lower() + + super().__init__() + self.x_format = x_format + if block_type in ["cross_attn", "ca"]: + self.block = VideoAttn( + x_dim, + context_dim, + num_heads, + bias=bias, + qkv_norm_mode=qkv_norm_mode, + x_format=self.x_format, + weight_args=weight_args, + operations=operations, + ) + elif block_type in ["full_attn", "fa"]: + self.block = VideoAttn( + x_dim, None, num_heads, bias=bias, qkv_norm_mode=qkv_norm_mode, x_format=self.x_format, weight_args=weight_args, operations=operations + ) + elif block_type in ["mlp", "ff"]: + self.block = GPT2FeedForward(x_dim, int(x_dim * mlp_ratio), dropout=mlp_dropout, bias=bias, weight_args=weight_args, operations=operations) + else: + raise ValueError(f"Unknown block type: {block_type}") + + self.block_type = block_type + self.use_adaln_lora = use_adaln_lora + + self.norm_state = nn.LayerNorm(x_dim, elementwise_affine=False, eps=1e-6) + self.n_adaln_chunks = 3 + if use_adaln_lora: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(x_dim, adaln_lora_dim, bias=False, **weight_args), + operations.Linear(adaln_lora_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args), + ) + else: + self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(x_dim, self.n_adaln_chunks * x_dim, bias=False, **weight_args)) + + def forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass for dynamically configured blocks with adaptive normalization. + + Args: + x (Tensor): Input tensor of shape (B, T, H, W, D) or (T, H, W, B, D). + emb_B_D (Tensor): Embedding tensor for adaptive layer normalization modulation. + crossattn_emb (Tensor): Tensor for cross-attention blocks. + crossattn_mask (Optional[Tensor]): Optional mask for cross-attention. + rope_emb_L_1_1_D (Optional[Tensor]): + Rotary positional embedding tensor of shape (L, 1, 1, D). L == THW for current video training. + + Returns: + Tensor: The output tensor after processing through the configured block and adaptive normalization. + """ + if self.use_adaln_lora: + shift_B_D, scale_B_D, gate_B_D = (self.adaLN_modulation(emb_B_D) + adaln_lora_B_3D).chunk( + self.n_adaln_chunks, dim=1 + ) + else: + shift_B_D, scale_B_D, gate_B_D = self.adaLN_modulation(emb_B_D).chunk(self.n_adaln_chunks, dim=1) + + shift_1_1_1_B_D, scale_1_1_1_B_D, gate_1_1_1_B_D = ( + shift_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + scale_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + gate_B_D.unsqueeze(0).unsqueeze(0).unsqueeze(0), + ) + + if self.block_type in ["mlp", "ff"]: + x = x + gate_1_1_1_B_D * self.block( + adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), + ) + elif self.block_type in ["full_attn", "fa"]: + x = x + gate_1_1_1_B_D * self.block( + adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), + context=None, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + elif self.block_type in ["cross_attn", "ca"]: + x = x + gate_1_1_1_B_D * self.block( + adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D), + context=crossattn_emb, + crossattn_mask=crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + ) + else: + raise ValueError(f"Unknown block type: {self.block_type}") + + return x + + +class GeneralDITTransformerBlock(nn.Module): + """ + A wrapper module that manages a sequence of DITBuildingBlocks to form a complete transformer layer. + Each block in the sequence is specified by a block configuration string. + + Parameters: + x_dim (int): Dimension of input features + context_dim (int): Dimension of context features for cross-attention blocks + num_heads (int): Number of attention heads + block_config (str): String specifying block sequence (e.g. "ca-fa-mlp" for cross-attention, + full-attention, then MLP) + mlp_ratio (float): MLP hidden dimension multiplier. Default: 4.0 + x_format (str): Input tensor format. Default: "BTHWD" + use_adaln_lora (bool): Whether to use AdaLN-LoRA. Default: False + adaln_lora_dim (int): Dimension for AdaLN-LoRA. Default: 256 + + The block_config string uses "-" to separate block types: + - "ca"/"cross_attn": Cross-attention block + - "fa"/"full_attn": Full self-attention block + - "mlp"/"ff": MLP/feedforward block + + Example: + block_config = "ca-fa-mlp" creates a sequence of: + 1. Cross-attention block + 2. Full self-attention block + 3. MLP block + """ + + def __init__( + self, + x_dim: int, + context_dim: int, + num_heads: int, + block_config: str, + mlp_ratio: float = 4.0, + x_format: str = "BTHWD", + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + weight_args={}, + operations=None + ): + super().__init__() + self.blocks = nn.ModuleList() + self.x_format = x_format + for block_type in block_config.split("-"): + self.blocks.append( + DITBuildingBlock( + block_type, + x_dim, + context_dim, + num_heads, + mlp_ratio, + x_format=self.x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + weight_args=weight_args, + operations=operations, + ) + ) + + def forward( + self, + x: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + rope_emb_L_1_1_D: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + extra_per_block_pos_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if extra_per_block_pos_emb is not None: + x = x + extra_per_block_pos_emb + for block in self.blocks: + x = block( + x, + emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + return x diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py b/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py new file mode 100644 index 00000000000..6149e53ec3b --- /dev/null +++ b/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py @@ -0,0 +1,1050 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The model definition for 3D layers + +Adapted from: https://github.com/lucidrains/magvit2-pytorch/blob/ +9f49074179c912736e617d61b32be367eb5f993a/magvit2_pytorch/magvit2_pytorch.py#L889 + +[MIT License Copyright (c) 2023 Phil Wang] +https://github.com/lucidrains/magvit2-pytorch/blob/ +9f49074179c912736e617d61b32be367eb5f993a/LICENSE +""" +import math +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import logging + +from .patching import ( + Patcher, + Patcher3D, + UnPatcher, + UnPatcher3D, +) +from .utils import ( + CausalNormalize, + batch2space, + batch2time, + cast_tuple, + is_odd, + nonlinearity, + replication_pad, + space2batch, + time2batch, +) + +import comfy.ops +ops = comfy.ops.disable_weight_init + +_LEGACY_NUM_GROUPS = 32 + + +class CausalConv3d(nn.Module): + def __init__( + self, + chan_in: int = 1, + chan_out: int = 1, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + pad_mode: str = "constant", + **kwargs, + ): + super().__init__() + kernel_size = cast_tuple(kernel_size, 3) + + time_kernel_size, height_kernel_size, width_kernel_size = kernel_size + + assert is_odd(height_kernel_size) and is_odd(width_kernel_size) + + dilation = kwargs.pop("dilation", 1) + stride = kwargs.pop("stride", 1) + time_stride = kwargs.pop("time_stride", 1) + time_dilation = kwargs.pop("time_dilation", 1) + padding = kwargs.pop("padding", 1) + + self.pad_mode = pad_mode + time_pad = time_dilation * (time_kernel_size - 1) + (1 - time_stride) + self.time_pad = time_pad + + self.spatial_pad = (padding, padding, padding, padding) + + stride = (time_stride, stride, stride) + dilation = (time_dilation, dilation, dilation) + self.conv3d = ops.Conv3d( + chan_in, + chan_out, + kernel_size, + stride=stride, + dilation=dilation, + **kwargs, + ) + + def _replication_pad(self, x: torch.Tensor) -> torch.Tensor: + x_prev = x[:, :, :1, ...].repeat(1, 1, self.time_pad, 1, 1) + x = torch.cat([x_prev, x], dim=2) + padding = self.spatial_pad + (0, 0) + return F.pad(x, padding, mode=self.pad_mode, value=0.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._replication_pad(x) + return self.conv3d(x) + + +class CausalUpsample3d(nn.Module): + def __init__(self, in_channels: int) -> None: + super().__init__() + self.conv = CausalConv3d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) + time_factor = 1.0 + 1.0 * (x.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + x = x.repeat_interleave(int(time_factor), dim=2) + # TODO(freda): Check if this causes temporal inconsistency. + # Shoule reverse the order of the following two ops, + # better perf and better temporal smoothness. + x = self.conv(x) + return x[..., int(time_factor - 1) :, :, :] + + +class CausalDownsample3d(nn.Module): + def __init__(self, in_channels: int) -> None: + super().__init__() + self.conv = CausalConv3d( + in_channels, + in_channels, + kernel_size=3, + stride=2, + time_stride=2, + padding=0, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad = (0, 1, 0, 1, 0, 0) + x = F.pad(x, pad, mode="constant", value=0) + x = replication_pad(x) + x = self.conv(x) + return x + + +class CausalHybridUpsample3d(nn.Module): + def __init__( + self, + in_channels: int, + spatial_up: bool = True, + temporal_up: bool = True, + **kwargs, + ) -> None: + super().__init__() + self.spatial_up = spatial_up + self.temporal_up = temporal_up + if not self.spatial_up and not self.temporal_up: + return + + self.conv1 = CausalConv3d( + in_channels, + in_channels, + kernel_size=(3, 1, 1), + stride=1, + time_stride=1, + padding=0, + ) + self.conv2 = CausalConv3d( + in_channels, + in_channels, + kernel_size=(1, 3, 3), + stride=1, + time_stride=1, + padding=1, + ) + self.conv3 = CausalConv3d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + time_stride=1, + padding=0, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_up and not self.temporal_up: + return x + + # hybrid upsample temporally. + if self.temporal_up: + time_factor = 1.0 + 1.0 * (x.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + x = x.repeat_interleave(int(time_factor), dim=2) + x = x[..., int(time_factor - 1) :, :, :] + x = self.conv1(x) + x + + # hybrid upsample spatially. + if self.spatial_up: + x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) + x = self.conv2(x) + x + + # final 1x1x1 conv. + x = self.conv3(x) + return x + + +class CausalHybridDownsample3d(nn.Module): + def __init__( + self, + in_channels: int, + spatial_down: bool = True, + temporal_down: bool = True, + **kwargs, + ) -> None: + super().__init__() + self.spatial_down = spatial_down + self.temporal_down = temporal_down + if not self.spatial_down and not self.temporal_down: + return + + self.conv1 = CausalConv3d( + in_channels, + in_channels, + kernel_size=(1, 3, 3), + stride=2, + time_stride=1, + padding=0, + ) + self.conv2 = CausalConv3d( + in_channels, + in_channels, + kernel_size=(3, 1, 1), + stride=1, + time_stride=2, + padding=0, + ) + self.conv3 = CausalConv3d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + time_stride=1, + padding=0, + ) + + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not self.spatial_down and not self.temporal_down: + return x + + # hybrid downsample spatially. + if self.spatial_down: + pad = (0, 1, 0, 1, 0, 0) + x = F.pad(x, pad, mode="constant", value=0) + x1 = self.conv1(x) + x2 = F.avg_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + x = x1 + x2 + + # hybrid downsample temporally. + if self.temporal_down: + x = replication_pad(x) + x1 = self.conv2(x) + x2 = F.avg_pool3d(x, kernel_size=(2, 1, 1), stride=(2, 1, 1)) + x = x1 + x2 + + # final 1x1x1 conv. + x = self.conv3(x) + return x + + +class CausalResnetBlock3d(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int = None, + dropout: float, + num_groups: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = CausalNormalize(in_channels, num_groups=num_groups) + self.conv1 = CausalConv3d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = CausalNormalize(out_channels, num_groups=num_groups) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = CausalConv3d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.nin_shortcut = ( + CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + x = self.nin_shortcut(x) + + return x + h + + +class CausalResnetBlockFactorized3d(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: int = None, + dropout: float, + num_groups: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = CausalNormalize(in_channels, num_groups=1) + self.conv1 = nn.Sequential( + CausalConv3d( + in_channels, + out_channels, + kernel_size=(1, 3, 3), + stride=1, + padding=1, + ), + CausalConv3d( + out_channels, + out_channels, + kernel_size=(3, 1, 1), + stride=1, + padding=0, + ), + ) + self.norm2 = CausalNormalize(out_channels, num_groups=num_groups) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = nn.Sequential( + CausalConv3d( + out_channels, + out_channels, + kernel_size=(1, 3, 3), + stride=1, + padding=1, + ), + CausalConv3d( + out_channels, + out_channels, + kernel_size=(3, 1, 1), + stride=1, + padding=0, + ), + ) + self.nin_shortcut = ( + CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + x = self.nin_shortcut(x) + + return x + h + + +class CausalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size = time2batch(q) + k, batch_size = time2batch(k) + v, batch_size = time2batch(v) + + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h * w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c) ** (-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = batch2time(h_, batch_size) + h_ = self.proj_out(h_) + return x + h_ + + +class CausalTemporalAttnBlock(nn.Module): + def __init__(self, in_channels: int, num_groups: int) -> None: + super().__init__() + + self.norm = CausalNormalize(in_channels, num_groups=num_groups) + self.q = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = CausalConv3d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + q, batch_size, height = space2batch(q) + k, _, _ = space2batch(k) + v, _, _ = space2batch(v) + + bhw, c, t = q.shape + q = q.permute(0, 2, 1) # (bhw, t, c) + k = k.permute(0, 2, 1) # (bhw, t, c) + v = v.permute(0, 2, 1) # (bhw, t, c) + + w_ = torch.bmm(q, k.permute(0, 2, 1)) # (bhw, t, t) + w_ = w_ * (int(c) ** (-0.5)) + + # Apply causal mask + mask = torch.tril(torch.ones_like(w_)) + w_ = w_.masked_fill(mask == 0, float("-inf")) + w_ = F.softmax(w_, dim=2) + + # attend to values + h_ = torch.bmm(w_, v) # (bhw, t, c) + h_ = h_.permute(0, 2, 1).reshape(bhw, c, t) # (bhw, c, t) + + h_ = batch2space(h_, batch_size, height) + h_ = self.proj_out(h_) + return x + h_ + + +class EncoderBase(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + **ignore_kwargs, + ) -> None: + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher = Patcher( + patch_size, ignore_kwargs.get("patch_method", "rearrange") + ) + in_channels = in_channels * patch_size * patch_size + + # downsampling + self.conv_in = CausalConv3d( + in_channels, channels, kernel_size=3, stride=1, padding=1 + ) + + # num of groups for GroupNorm, num_groups=1 for LayerNorm. + num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS) + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=num_groups, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(CausalAttnBlock(block_in, num_groups=num_groups)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = CausalDownsample3d(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups) + self.mid.block_2 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + + # end + self.norm_out = CausalNormalize(block_in, num_groups=num_groups) + self.conv_out = CausalConv3d( + block_in, z_channels, kernel_size=3, stride=1, padding=1 + ) + + def patcher3d(self, x: torch.Tensor) -> torch.Tensor: + x, batch_size = time2batch(x) + x = self.patcher(x) + x = batch2time(x, batch_size) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher3d(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + else: + # temporal downsample (last level) + time_factor = 1 + 1 * (hs[-1].shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + hs[-1] = replication_pad(hs[-1]) + hs.append( + F.avg_pool3d( + hs[-1], + kernel_size=[time_factor, 1, 1], + stride=[2, 1, 1], + ) + ) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class DecoderBase(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher = UnPatcher( + patch_size, ignore_kwargs.get("patch_method", "rearrange") + ) + out_ch = out_channels * patch_size * patch_size + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + logging.debug( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + # z to block_in + self.conv_in = CausalConv3d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # num of groups for GroupNorm, num_groups=1 for LayerNorm. + num_groups = ignore_kwargs.get("num_groups", _LEGACY_NUM_GROUPS) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + self.mid.attn_1 = CausalAttnBlock(block_in, num_groups=num_groups) + self.mid.block_2 = CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=num_groups, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + CausalResnetBlock3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=num_groups, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(CausalAttnBlock(block_in, num_groups=num_groups)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = CausalUpsample3d(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = CausalNormalize(block_in, num_groups=num_groups) + self.conv_out = CausalConv3d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def unpatcher3d(self, x: torch.Tensor) -> torch.Tensor: + x, batch_size = time2batch(x) + x = self.unpatcher(x) + x = batch2time(x, batch_size) + + return x + + def forward(self, z): + h = self.conv_in(z) + + # middle block. + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # decoder blocks. + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + else: + # temporal upsample (last level) + time_factor = 1.0 + 1.0 * (h.shape[2] > 1) + if isinstance(time_factor, torch.Tensor): + time_factor = time_factor.item() + h = h.repeat_interleave(int(time_factor), dim=2) + h = h[..., int(time_factor - 1) :, :, :] + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher3d(h) + return h + + +class EncoderFactorized(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int = 8, + temporal_compression: int = 8, + **ignore_kwargs, + ) -> None: + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # Patcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.patcher3d = Patcher3D( + patch_size, ignore_kwargs.get("patch_method", "haar") + ) + in_channels = in_channels * patch_size * patch_size * patch_size + + # calculate the number of downsample operations + self.num_spatial_downs = int(math.log2(spatial_compression)) - int( + math.log2(patch_size) + ) + assert ( + self.num_spatial_downs <= self.num_resolutions + ), f"Spatially downsample {self.num_resolutions} times at most" + + self.num_temporal_downs = int(math.log2(temporal_compression)) - int( + math.log2(patch_size) + ) + assert ( + self.num_temporal_downs <= self.num_resolutions + ), f"Temporally downsample {self.num_resolutions} times at most" + + # downsampling + self.conv_in = nn.Sequential( + CausalConv3d( + in_channels, + channels, + kernel_size=(1, 3, 3), + stride=1, + padding=1, + ), + CausalConv3d( + channels, channels, kernel_size=(3, 1, 1), stride=1, padding=0 + ), + ) + + curr_res = resolution // patch_size + in_ch_mult = (1,) + tuple(channels_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = channels * in_ch_mult[i_level] + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks): + block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=1, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + ) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + spatial_down = i_level < self.num_spatial_downs + temporal_down = i_level < self.num_temporal_downs + down.downsample = CausalHybridDownsample3d( + block_in, + spatial_down=spatial_down, + temporal_down=temporal_down, + ) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d( + block_in, z_channels, kernel_size=(1, 3, 3), stride=1, padding=1 + ), + CausalConv3d( + z_channels, + z_channels, + kernel_size=(3, 1, 1), + stride=1, + padding=0, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patcher3d(x) + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class DecoderFactorized(nn.Module): + def __init__( + self, + out_channels: int, + channels: int, + channels_mult: list[int], + num_res_blocks: int, + attn_resolutions: list[int], + dropout: float, + resolution: int, + z_channels: int, + spatial_compression: int = 8, + temporal_compression: int = 8, + **ignore_kwargs, + ): + super().__init__() + self.num_resolutions = len(channels_mult) + self.num_res_blocks = num_res_blocks + + # UnPatcher. + patch_size = ignore_kwargs.get("patch_size", 1) + self.unpatcher3d = UnPatcher3D( + patch_size, ignore_kwargs.get("patch_method", "haar") + ) + out_ch = out_channels * patch_size * patch_size * patch_size + + # calculate the number of upsample operations + self.num_spatial_ups = int(math.log2(spatial_compression)) - int( + math.log2(patch_size) + ) + assert ( + self.num_spatial_ups <= self.num_resolutions + ), f"Spatially upsample {self.num_resolutions} times at most" + self.num_temporal_ups = int(math.log2(temporal_compression)) - int( + math.log2(patch_size) + ) + assert ( + self.num_temporal_ups <= self.num_resolutions + ), f"Temporally upsample {self.num_resolutions} times at most" + + block_in = channels * channels_mult[self.num_resolutions - 1] + curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + logging.debug( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + # z to block_in + self.conv_in = nn.Sequential( + CausalConv3d( + z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=1 + ), + CausalConv3d( + block_in, block_in, kernel_size=(3, 1, 1), stride=1, padding=0 + ), + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + self.mid.attn_1 = nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + self.mid.block_2 = CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_in, + dropout=dropout, + num_groups=1, + ) + + legacy_mode = ignore_kwargs.get("legacy_mode", False) + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = channels * channels_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + CausalResnetBlockFactorized3d( + in_channels=block_in, + out_channels=block_out, + dropout=dropout, + num_groups=1, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append( + nn.Sequential( + CausalAttnBlock(block_in, num_groups=1), + CausalTemporalAttnBlock(block_in, num_groups=1), + ) + ) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + # The layer index for temporal/spatial downsampling performed + # in the encoder should correspond to the layer index in + # reverse order where upsampling is performed in the decoder. + # If you've a pre-trained model, you can simply finetune. + i_level_reverse = self.num_resolutions - i_level - 1 + if legacy_mode: + temporal_up = i_level_reverse < self.num_temporal_ups + else: + temporal_up = 0 < i_level_reverse < self.num_temporal_ups + 1 + spatial_up = temporal_up or ( + i_level_reverse < self.num_spatial_ups + and self.num_spatial_ups > self.num_temporal_ups + ) + up.upsample = CausalHybridUpsample3d( + block_in, spatial_up=spatial_up, temporal_up=temporal_up + ) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = CausalNormalize(block_in, num_groups=1) + self.conv_out = nn.Sequential( + CausalConv3d(block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=1), + CausalConv3d(out_ch, out_ch, kernel_size=(3, 1, 1), stride=1, padding=0), + ) + + def forward(self, z): + h = self.conv_in(z) + + # middle block. + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + + # decoder blocks. + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = self.unpatcher3d(h) + return h diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/patching.py b/comfy/ldm/cosmos/cosmos_tokenizer/patching.py new file mode 100644 index 00000000000..793f0da8a9e --- /dev/null +++ b/comfy/ldm/cosmos/cosmos_tokenizer/patching.py @@ -0,0 +1,355 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The patcher and unpatcher implementation for 2D and 3D data. + +The idea of Haar wavelet is to compute LL, LH, HL, HH component as two 1D convolutions. +One on the rows and one on the columns. +For example, in 1D signal, we have [a, b], then the low-freq compoenent is [a + b] / 2 and high-freq is [a - b] / 2. +We can use a 1D convolution with kernel [1, 1] and stride 2 to represent the L component. +For H component, we can use a 1D convolution with kernel [1, -1] and stride 2. +Although in principle, we typically only do additional Haar wavelet over the LL component. But here we do it for all + as we need to support downsampling for more than 2x. +For example, 4x downsampling can be done by 2x Haar and additional 2x Haar, and the shape would be. + [3, 256, 256] -> [12, 128, 128] -> [48, 64, 64] +""" + +import torch +import torch.nn.functional as F +from einops import rearrange + +_WAVELETS = { + "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]), + "rearrange": torch.tensor([1.0, 1.0]), +} +_PERSISTENT = False + + +class Patcher(torch.nn.Module): + """A module to convert image tensors into patches using torch operations. + + The main difference from `class Patching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purpose. + + It's bit-wise identical to the Patching module outputs, with the added + benefit of being torch.jit scriptable. + """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer( + "wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT + ) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer( + "_arange", + torch.arange(_WAVELETS[patch_method].shape[0]), + persistent=_PERSISTENT, + ) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._haar(x) + elif self.patch_method == "rearrange": + return self._arrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _dwt(self, x, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets.to(device=x.device) + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype) + xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2)) + xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2)) + xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1)) + xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1)) + xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1)) + + out = torch.cat([xll, xlh, xhl, xhh], dim=1) + if rescale: + out = out / 2 + return out + + def _haar(self, x): + for _ in self.range: + x = self._dwt(x, rescale=True) + return x + + def _arrange(self, x): + x = rearrange( + x, + "b c (h p1) (w p2) -> b (c p1 p2) h w", + p1=self.patch_size, + p2=self.patch_size, + ).contiguous() + return x + + +class Patcher3D(Patcher): + """A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + self.register_buffer( + "patch_size_buffer", + patch_size * torch.ones([1], dtype=torch.int32), + persistent=_PERSISTENT, + ) + + def _dwt(self, x, wavelet, mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets.to(device=x.device) + + n = h.shape[0] + g = x.shape[1] + hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1) + hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + # Handles temporal axis. + x = F.pad( + x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode + ).to(dtype) + xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)) + + # Handles spatial axes. + xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)) + + xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)) + + out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1) + if rescale: + out = out / (2 * torch.sqrt(torch.tensor(2.0))) + return out + + def _haar(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + for _ in self.range: + x = self._dwt(x, "haar", rescale=True) + return x + + def _arrange(self, x): + xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2) + x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2) + x = rearrange( + x, + "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ).contiguous() + return x + + +class UnPatcher(torch.nn.Module): + """A module to convert patches into image tensorsusing torch operations. + + The main difference from `class Unpatching` is that this module implements + all operations using torch, rather than python or numpy, for efficiency purpose. + + It's bit-wise identical to the Unpatching module outputs, with the added + benefit of being torch.jit scriptable. + """ + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__() + self.patch_size = patch_size + self.patch_method = patch_method + self.register_buffer( + "wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT + ) + self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item())) + self.register_buffer( + "_arange", + torch.arange(_WAVELETS[patch_method].shape[0]), + persistent=_PERSISTENT, + ) + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x): + if self.patch_method == "haar": + return self._ihaar(x) + elif self.patch_method == "rearrange": + return self._iarrange(x) + else: + raise ValueError("Unknown patch method: " + self.patch_method) + + def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets.to(device=x.device) + n = h.shape[0] + + g = x.shape[1] // 4 + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1) + hh = hh.to(dtype=dtype) + hl = hl.to(dtype=dtype) + + xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1) + + # Inverse transform. + yl = torch.nn.functional.conv_transpose2d( + xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0) + ) + yl += torch.nn.functional.conv_transpose2d( + xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0) + ) + yh = torch.nn.functional.conv_transpose2d( + xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0) + ) + yh += torch.nn.functional.conv_transpose2d( + xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0) + ) + y = torch.nn.functional.conv_transpose2d( + yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2) + ) + y += torch.nn.functional.conv_transpose2d( + yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2) + ) + + if rescale: + y = y * 2 + return y + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, "haar", rescale=True) + return x + + def _iarrange(self, x): + x = rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.patch_size, + p2=self.patch_size, + ) + return x + + +class UnPatcher3D(UnPatcher): + """A 3D inverse discrete wavelet transform for video wavelet decompositions.""" + + def __init__(self, patch_size=1, patch_method="haar"): + super().__init__(patch_method=patch_method, patch_size=patch_size) + + def _idwt(self, x, wavelet="haar", mode="reflect", rescale=False): + dtype = x.dtype + h = self.wavelets.to(device=x.device) + + g = x.shape[1] // 8 # split into 8 spatio-temporal filtered tesnors. + hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1]) + hh = (h * ((-1) ** self._arange.to(device=x.device))).reshape(1, 1, -1).repeat(g, 1, 1) + hl = hl.to(dtype=dtype) + hh = hh.to(dtype=dtype) + + xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1) + + # Height height transposed convolutions. + xll = F.conv_transpose3d( + xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + xll += F.conv_transpose3d( + xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + + xlh = F.conv_transpose3d( + xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + xlh += F.conv_transpose3d( + xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + + xhl = F.conv_transpose3d( + xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + xhl += F.conv_transpose3d( + xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + + xhh = F.conv_transpose3d( + xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + xhh += F.conv_transpose3d( + xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2) + ) + + # Handles width transposed convolutions. + xl = F.conv_transpose3d( + xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) + ) + xl += F.conv_transpose3d( + xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) + ) + xh = F.conv_transpose3d( + xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) + ) + xh += F.conv_transpose3d( + xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1) + ) + + # Handles time axis transposed convolutions. + x = F.conv_transpose3d( + xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1) + ) + x += F.conv_transpose3d( + xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1) + ) + + if rescale: + x = x * (2 * torch.sqrt(torch.tensor(2.0))) + return x + + def _ihaar(self, x): + for _ in self.range: + x = self._idwt(x, "haar", rescale=True) + x = x[:, :, self.patch_size - 1 :, ...] + return x + + def _iarrange(self, x): + x = rearrange( + x, + "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)", + p1=self.patch_size, + p2=self.patch_size, + p3=self.patch_size, + ) + x = x[:, :, self.patch_size - 1 :, ...] + return x diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/utils.py b/comfy/ldm/cosmos/cosmos_tokenizer/utils.py new file mode 100644 index 00000000000..64dd5e288a0 --- /dev/null +++ b/comfy/ldm/cosmos/cosmos_tokenizer/utils.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared utilities for the networks module.""" + +from typing import Any + +import torch +from einops import pack, rearrange, unpack + + +import comfy.ops +ops = comfy.ops.disable_weight_init + +def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: + batch_size = x.shape[0] + return rearrange(x, "b c t h w -> (b t) c h w"), batch_size + + +def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor: + return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size) + + +def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: + batch_size, height = x.shape[0], x.shape[-2] + return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height + + +def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor: + return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height) + + +def cast_tuple(t: Any, length: int = 1) -> Any: + return t if isinstance(t, tuple) else ((t,) * length) + + +def replication_pad(x): + return torch.cat([x[:, :, :1, ...], x], dim=2) + + +def divisible_by(num: int, den: int) -> bool: + return (num % den) == 0 + + +def is_odd(n: int) -> bool: + return not divisible_by(n, 2) + + +def nonlinearity(x): + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return ops.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class CausalNormalize(torch.nn.Module): + def __init__(self, in_channels, num_groups=1): + super().__init__() + self.norm = ops.GroupNorm( + num_groups=num_groups, + num_channels=in_channels, + eps=1e-6, + affine=True, + ) + self.num_groups = num_groups + + def forward(self, x): + # if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose. + # All new models should use num_groups=1, otherwise causality is not guaranteed. + if self.num_groups == 1: + x, batch_size = time2batch(x) + return batch2time(self.norm(x), batch_size) + return self.norm(x) + + +def exists(v): + return v is not None + + +def default(*args): + for arg in args: + if exists(arg): + return arg + return None + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +def round_ste(z: torch.Tensor) -> torch.Tensor: + """Round with straight through gradients.""" + zhat = z.round() + return z + (zhat - z).detach() + + +def log(t, eps=1e-5): + return t.clamp(min=eps).log() + + +def entropy(prob): + return (-prob * log(prob)).sum(dim=-1) diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py new file mode 100644 index 00000000000..05dd3846960 --- /dev/null +++ b/comfy/ldm/cosmos/model.py @@ -0,0 +1,510 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. +""" + +from typing import Optional, Tuple + +import torch +from einops import rearrange +from torch import nn +from torchvision import transforms + +from enum import Enum +import logging + +from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm + +from .blocks import ( + FinalLayer, + GeneralDITTransformerBlock, + PatchEmbed, + TimestepEmbedding, + Timesteps, +) + +from .position_embedding import LearnablePosEmbAxis, VideoRopePosition3DEmb + + +class DataType(Enum): + IMAGE = "image" + VIDEO = "video" + + +class GeneralDIT(nn.Module): + """ + A general implementation of adaln-modulated VIT-like~(DiT) transformer for video processing. + + Args: + max_img_h (int): Maximum height of the input images. + max_img_w (int): Maximum width of the input images. + max_frames (int): Maximum number of frames in the video sequence. + in_channels (int): Number of input channels (e.g., RGB channels for color images). + out_channels (int): Number of output channels. + patch_spatial (tuple): Spatial resolution of patches for input processing. + patch_temporal (int): Temporal resolution of patches for input processing. + concat_padding_mask (bool): If True, includes a mask channel in the input to handle padding. + block_config (str): Configuration of the transformer block. See Notes for supported block types. + model_channels (int): Base number of channels used throughout the model. + num_blocks (int): Number of transformer blocks. + num_heads (int): Number of heads in the multi-head attention layers. + mlp_ratio (float): Expansion ratio for MLP blocks. + block_x_format (str): Format of input tensor for transformer blocks ('BTHWD' or 'THWBD'). + crossattn_emb_channels (int): Number of embedding channels for cross-attention. + use_cross_attn_mask (bool): Whether to use mask in cross-attention. + pos_emb_cls (str): Type of positional embeddings. + pos_emb_learnable (bool): Whether positional embeddings are learnable. + pos_emb_interpolation (str): Method for interpolating positional embeddings. + affline_emb_norm (bool): Whether to normalize affine embeddings. + use_adaln_lora (bool): Whether to use AdaLN-LoRA. + adaln_lora_dim (int): Dimension for AdaLN-LoRA. + rope_h_extrapolation_ratio (float): Height extrapolation ratio for RoPE. + rope_w_extrapolation_ratio (float): Width extrapolation ratio for RoPE. + rope_t_extrapolation_ratio (float): Temporal extrapolation ratio for RoPE. + extra_per_block_abs_pos_emb (bool): Whether to use extra per-block absolute positional embeddings. + extra_per_block_abs_pos_emb_type (str): Type of extra per-block positional embeddings. + extra_h_extrapolation_ratio (float): Height extrapolation ratio for extra embeddings. + extra_w_extrapolation_ratio (float): Width extrapolation ratio for extra embeddings. + extra_t_extrapolation_ratio (float): Temporal extrapolation ratio for extra embeddings. + + Notes: + Supported block types in block_config: + * cross_attn, ca: Cross attention + * full_attn: Full attention on all flattened tokens + * mlp, ff: Feed forward block + """ + + def __init__( + self, + max_img_h: int, + max_img_w: int, + max_frames: int, + in_channels: int, + out_channels: int, + patch_spatial: tuple, + patch_temporal: int, + concat_padding_mask: bool = True, + # attention settings + block_config: str = "FA-CA-MLP", + model_channels: int = 768, + num_blocks: int = 10, + num_heads: int = 16, + mlp_ratio: float = 4.0, + block_x_format: str = "BTHWD", + # cross attention settings + crossattn_emb_channels: int = 1024, + use_cross_attn_mask: bool = False, + # positional embedding settings + pos_emb_cls: str = "sincos", + pos_emb_learnable: bool = False, + pos_emb_interpolation: str = "crop", + affline_emb_norm: bool = False, # whether or not to normalize the affine embedding + use_adaln_lora: bool = False, + adaln_lora_dim: int = 256, + rope_h_extrapolation_ratio: float = 1.0, + rope_w_extrapolation_ratio: float = 1.0, + rope_t_extrapolation_ratio: float = 1.0, + extra_per_block_abs_pos_emb: bool = False, + extra_per_block_abs_pos_emb_type: str = "sincos", + extra_h_extrapolation_ratio: float = 1.0, + extra_w_extrapolation_ratio: float = 1.0, + extra_t_extrapolation_ratio: float = 1.0, + image_model=None, + device=None, + dtype=None, + operations=None, + ) -> None: + super().__init__() + self.max_img_h = max_img_h + self.max_img_w = max_img_w + self.max_frames = max_frames + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.num_heads = num_heads + self.num_blocks = num_blocks + self.model_channels = model_channels + self.use_cross_attn_mask = use_cross_attn_mask + self.concat_padding_mask = concat_padding_mask + # positional embedding settings + self.pos_emb_cls = pos_emb_cls + self.pos_emb_learnable = pos_emb_learnable + self.pos_emb_interpolation = pos_emb_interpolation + self.affline_emb_norm = affline_emb_norm + self.rope_h_extrapolation_ratio = rope_h_extrapolation_ratio + self.rope_w_extrapolation_ratio = rope_w_extrapolation_ratio + self.rope_t_extrapolation_ratio = rope_t_extrapolation_ratio + self.extra_per_block_abs_pos_emb = extra_per_block_abs_pos_emb + self.extra_per_block_abs_pos_emb_type = extra_per_block_abs_pos_emb_type.lower() + self.extra_h_extrapolation_ratio = extra_h_extrapolation_ratio + self.extra_w_extrapolation_ratio = extra_w_extrapolation_ratio + self.extra_t_extrapolation_ratio = extra_t_extrapolation_ratio + self.dtype = dtype + weight_args = {"device": device, "dtype": dtype} + + in_channels = in_channels + 1 if concat_padding_mask else in_channels + self.x_embedder = PatchEmbed( + spatial_patch_size=patch_spatial, + temporal_patch_size=patch_temporal, + in_channels=in_channels, + out_channels=model_channels, + bias=False, + weight_args=weight_args, + operations=operations, + ) + + self.build_pos_embed(device=device) + self.block_x_format = block_x_format + self.use_adaln_lora = use_adaln_lora + self.adaln_lora_dim = adaln_lora_dim + self.t_embedder = nn.ModuleList( + [Timesteps(model_channels), + TimestepEmbedding(model_channels, model_channels, use_adaln_lora=use_adaln_lora, weight_args=weight_args, operations=operations),] + ) + + self.blocks = nn.ModuleDict() + + for idx in range(num_blocks): + self.blocks[f"block{idx}"] = GeneralDITTransformerBlock( + x_dim=model_channels, + context_dim=crossattn_emb_channels, + num_heads=num_heads, + block_config=block_config, + mlp_ratio=mlp_ratio, + x_format=self.block_x_format, + use_adaln_lora=use_adaln_lora, + adaln_lora_dim=adaln_lora_dim, + weight_args=weight_args, + operations=operations, + ) + + if self.affline_emb_norm: + logging.debug("Building affine embedding normalization layer") + self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6) + else: + self.affline_norm = nn.Identity() + + self.final_layer = FinalLayer( + hidden_size=self.model_channels, + spatial_patch_size=self.patch_spatial, + temporal_patch_size=self.patch_temporal, + out_channels=self.out_channels, + use_adaln_lora=self.use_adaln_lora, + adaln_lora_dim=self.adaln_lora_dim, + weight_args=weight_args, + operations=operations, + ) + + def build_pos_embed(self, device=None): + if self.pos_emb_cls == "rope3d": + cls_type = VideoRopePosition3DEmb + else: + raise ValueError(f"Unknown pos_emb_cls {self.pos_emb_cls}") + + logging.debug(f"Building positional embedding with {self.pos_emb_cls} class, impl {cls_type}") + kwargs = dict( + model_channels=self.model_channels, + len_h=self.max_img_h // self.patch_spatial, + len_w=self.max_img_w // self.patch_spatial, + len_t=self.max_frames // self.patch_temporal, + is_learnable=self.pos_emb_learnable, + interpolation=self.pos_emb_interpolation, + head_dim=self.model_channels // self.num_heads, + h_extrapolation_ratio=self.rope_h_extrapolation_ratio, + w_extrapolation_ratio=self.rope_w_extrapolation_ratio, + t_extrapolation_ratio=self.rope_t_extrapolation_ratio, + device=device, + ) + self.pos_embedder = cls_type( + **kwargs, + ) + + if self.extra_per_block_abs_pos_emb: + assert self.extra_per_block_abs_pos_emb_type in [ + "learnable", + ], f"Unknown extra_per_block_abs_pos_emb_type {self.extra_per_block_abs_pos_emb_type}" + kwargs["h_extrapolation_ratio"] = self.extra_h_extrapolation_ratio + kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio + kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio + kwargs["device"] = device + self.extra_pos_embedder = LearnablePosEmbAxis( + **kwargs, + ) + + def prepare_embedded_sequence( + self, + x_B_C_T_H_W: torch.Tensor, + fps: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks. + + Args: + x_B_C_T_H_W (torch.Tensor): video + fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required. + If None, a default value (`self.base_fps`) will be used. + padding_mask (Optional[torch.Tensor]): current it is not used + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - A tensor of shape (B, T, H, W, D) with the embedded sequence. + - An optional positional embedding tensor, returned only if the positional embedding class + (`self.pos_emb_cls`) includes 'rope'. Otherwise, None. + + Notes: + - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor. + - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`. + - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using + the `self.pos_embedder` with the shape [T, H, W]. + - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the + `self.pos_embedder` with the fps tensor. + - Otherwise, the positional embeddings are generated without considering fps. + """ + if self.concat_padding_mask: + if padding_mask is not None: + padding_mask = transforms.functional.resize( + padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST + ) + else: + padding_mask = torch.zeros((x_B_C_T_H_W.shape[0], 1, x_B_C_T_H_W.shape[-2], x_B_C_T_H_W.shape[-1]), dtype=x_B_C_T_H_W.dtype, device=x_B_C_T_H_W.device) + + x_B_C_T_H_W = torch.cat( + [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1 + ) + x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W) + + if self.extra_per_block_abs_pos_emb: + extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device) + else: + extra_pos_emb = None + + if "rope" in self.pos_emb_cls.lower(): + return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device), extra_pos_emb + + if "fps_aware" in self.pos_emb_cls: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device) # [B, T, H, W, D] + else: + x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, device=x_B_C_T_H_W.device) # [B, T, H, W, D] + + return x_B_T_H_W_D, None, extra_pos_emb + + def decoder_head( + self, + x_B_T_H_W_D: torch.Tensor, + emb_B_D: torch.Tensor, + crossattn_emb: torch.Tensor, + origin_shape: Tuple[int, int, int, int, int], # [B, C, T, H, W] + crossattn_mask: Optional[torch.Tensor] = None, + adaln_lora_B_3D: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + del crossattn_emb, crossattn_mask + B, C, T_before_patchify, H_before_patchify, W_before_patchify = origin_shape + x_BT_HW_D = rearrange(x_B_T_H_W_D, "B T H W D -> (B T) (H W) D") + x_BT_HW_D = self.final_layer(x_BT_HW_D, emb_B_D, adaln_lora_B_3D=adaln_lora_B_3D) + # This is to ensure x_BT_HW_D has the correct shape because + # when we merge T, H, W into one dimension, x_BT_HW_D has shape (B * T * H * W, 1*1, D). + x_BT_HW_D = x_BT_HW_D.view( + B * T_before_patchify // self.patch_temporal, + H_before_patchify // self.patch_spatial * W_before_patchify // self.patch_spatial, + -1, + ) + x_B_D_T_H_W = rearrange( + x_BT_HW_D, + "(B T) (H W) (p1 p2 t C) -> B C (T t) (H p1) (W p2)", + p1=self.patch_spatial, + p2=self.patch_spatial, + H=H_before_patchify // self.patch_spatial, + W=W_before_patchify // self.patch_spatial, + t=self.patch_temporal, + B=B, + ) + return x_B_D_T_H_W + + def forward_before_blocks( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + crossattn_emb: torch.Tensor, + crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + """ + del kwargs + assert isinstance( + data_type, DataType + ), f"Expected DataType, got {type(data_type)}. We need discuss this flag later." + original_shape = x.shape + x_B_T_H_W_D, rope_emb_L_1_1_D, extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = self.prepare_embedded_sequence( + x, + fps=fps, + padding_mask=padding_mask, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + ) + # logging affline scale information + affline_scale_log_info = {} + + timesteps_B_D, adaln_lora_B_3D = self.t_embedder[1](self.t_embedder[0](timesteps.flatten()).to(x.dtype)) + affline_emb_B_D = timesteps_B_D + affline_scale_log_info["timesteps_B_D"] = timesteps_B_D.detach() + + if scalar_feature is not None: + raise NotImplementedError("Scalar feature is not implemented yet.") + + affline_scale_log_info["affline_emb_B_D"] = affline_emb_B_D.detach() + affline_emb_B_D = self.affline_norm(affline_emb_B_D) + + if self.use_cross_attn_mask: + if crossattn_mask is not None and not torch.is_floating_point(crossattn_mask): + crossattn_mask = (crossattn_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max + crossattn_mask = crossattn_mask[:, None, None, :] # .to(dtype=torch.bool) # [B, 1, 1, length] + else: + crossattn_mask = None + + if self.blocks["block0"].x_format == "THWBD": + x = rearrange(x_B_T_H_W_D, "B T H W D -> T H W B D") + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = rearrange( + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "B T H W D -> T H W B D" + ) + crossattn_emb = rearrange(crossattn_emb, "B M D -> M B D") + + if crossattn_mask: + crossattn_mask = rearrange(crossattn_mask, "B M -> M B") + + elif self.blocks["block0"].x_format == "BTHWD": + x = x_B_T_H_W_D + else: + raise ValueError(f"Unknown x_format {self.blocks[0].x_format}") + output = { + "x": x, + "affline_emb_B_D": affline_emb_B_D, + "crossattn_emb": crossattn_emb, + "crossattn_mask": crossattn_mask, + "rope_emb_L_1_1_D": rope_emb_L_1_1_D, + "adaln_lora_B_3D": adaln_lora_B_3D, + "original_shape": original_shape, + "extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + } + return output + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + # crossattn_emb: torch.Tensor, + # crossattn_mask: Optional[torch.Tensor] = None, + fps: Optional[torch.Tensor] = None, + image_size: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + scalar_feature: Optional[torch.Tensor] = None, + data_type: Optional[DataType] = DataType.VIDEO, + latent_condition: Optional[torch.Tensor] = None, + latent_condition_sigma: Optional[torch.Tensor] = None, + condition_video_augment_sigma: Optional[torch.Tensor] = None, + **kwargs, + ): + """ + Args: + x: (B, C, T, H, W) tensor of spatial-temp inputs + timesteps: (B, ) tensor of timesteps + crossattn_emb: (B, N, D) tensor of cross-attention embeddings + crossattn_mask: (B, N) tensor of cross-attention masks + condition_video_augment_sigma: (B,) used in lvg(long video generation), we add noise with this sigma to + augment condition input, the lvg model will condition on the condition_video_augment_sigma value; + we need forward_before_blocks pass to the forward_before_blocks function. + """ + + crossattn_emb = context + crossattn_mask = attention_mask + + inputs = self.forward_before_blocks( + x=x, + timesteps=timesteps, + crossattn_emb=crossattn_emb, + crossattn_mask=crossattn_mask, + fps=fps, + image_size=image_size, + padding_mask=padding_mask, + scalar_feature=scalar_feature, + data_type=data_type, + latent_condition=latent_condition, + latent_condition_sigma=latent_condition_sigma, + condition_video_augment_sigma=condition_video_augment_sigma, + **kwargs, + ) + x, affline_emb_B_D, crossattn_emb, crossattn_mask, rope_emb_L_1_1_D, adaln_lora_B_3D, original_shape = ( + inputs["x"], + inputs["affline_emb_B_D"], + inputs["crossattn_emb"], + inputs["crossattn_mask"], + inputs["rope_emb_L_1_1_D"], + inputs["adaln_lora_B_3D"], + inputs["original_shape"], + ) + extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype) + if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None: + assert ( + x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape + ), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}" + + for _, block in self.blocks.items(): + assert ( + self.blocks["block0"].x_format == block.x_format + ), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}" + + x = block( + x, + affline_emb_B_D, + crossattn_emb, + crossattn_mask, + rope_emb_L_1_1_D=rope_emb_L_1_1_D, + adaln_lora_B_3D=adaln_lora_B_3D, + extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, + ) + + x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D") + + x_B_D_T_H_W = self.decoder_head( + x_B_T_H_W_D=x_B_T_H_W_D, + emb_B_D=affline_emb_B_D, + crossattn_emb=None, + origin_shape=original_shape, + crossattn_mask=None, + adaln_lora_B_3D=adaln_lora_B_3D, + ) + + return x_B_D_T_H_W diff --git a/comfy/ldm/cosmos/position_embedding.py b/comfy/ldm/cosmos/position_embedding.py new file mode 100644 index 00000000000..dda752cb875 --- /dev/null +++ b/comfy/ldm/cosmos/position_embedding.py @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +import torch +from einops import rearrange, repeat +from torch import nn +import math + + +def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0) -> torch.Tensor: + """ + Normalizes the input tensor along specified dimensions such that the average square norm of elements is adjusted. + + Args: + x (torch.Tensor): The input tensor to normalize. + dim (list, optional): The dimensions over which to normalize. If None, normalizes over all dimensions except the first. + eps (float, optional): A small constant to ensure numerical stability during division. + + Returns: + torch.Tensor: The normalized tensor. + """ + if dim is None: + dim = list(range(1, x.ndim)) + norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) + norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel())) + return x / norm.to(x.dtype) + + +class VideoPositionEmb(nn.Module): + def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None) -> torch.Tensor: + """ + It delegates the embedding generation to generate_embeddings function. + """ + B_T_H_W_C = x_B_T_H_W_C.shape + embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device) + + return embeddings + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None): + raise NotImplementedError + + +class VideoRopePosition3DEmb(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + head_dim: int, + len_h: int, + len_w: int, + len_t: int, + base_fps: int = 24, + h_extrapolation_ratio: float = 1.0, + w_extrapolation_ratio: float = 1.0, + t_extrapolation_ratio: float = 1.0, + device=None, + **kwargs, # used for compatibility with other positional embeddings; unused in this class + ): + del kwargs + super().__init__() + self.register_buffer("seq", torch.arange(max(len_h, len_w, len_t), dtype=torch.float, device=device)) + self.base_fps = base_fps + self.max_h = len_h + self.max_w = len_w + + dim = head_dim + dim_h = dim // 6 * 2 + dim_w = dim_h + dim_t = dim - 2 * dim_h + assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}" + self.register_buffer( + "dim_spatial_range", + torch.arange(0, dim_h, 2, device=device)[: (dim_h // 2)].float() / dim_h, + persistent=False, + ) + self.register_buffer( + "dim_temporal_range", + torch.arange(0, dim_t, 2, device=device)[: (dim_t // 2)].float() / dim_t, + persistent=False, + ) + + self.h_ntk_factor = h_extrapolation_ratio ** (dim_h / (dim_h - 2)) + self.w_ntk_factor = w_extrapolation_ratio ** (dim_w / (dim_w - 2)) + self.t_ntk_factor = t_extrapolation_ratio ** (dim_t / (dim_t - 2)) + + def generate_embeddings( + self, + B_T_H_W_C: torch.Size, + fps: Optional[torch.Tensor] = None, + h_ntk_factor: Optional[float] = None, + w_ntk_factor: Optional[float] = None, + t_ntk_factor: Optional[float] = None, + device=None, + ): + """ + Generate embeddings for the given input size. + + Args: + B_T_H_W_C (torch.Size): Input tensor size (Batch, Time, Height, Width, Channels). + fps (Optional[torch.Tensor], optional): Frames per second. Defaults to None. + h_ntk_factor (Optional[float], optional): Height NTK factor. If None, uses self.h_ntk_factor. + w_ntk_factor (Optional[float], optional): Width NTK factor. If None, uses self.w_ntk_factor. + t_ntk_factor (Optional[float], optional): Time NTK factor. If None, uses self.t_ntk_factor. + + Returns: + Not specified in the original code snippet. + """ + h_ntk_factor = h_ntk_factor if h_ntk_factor is not None else self.h_ntk_factor + w_ntk_factor = w_ntk_factor if w_ntk_factor is not None else self.w_ntk_factor + t_ntk_factor = t_ntk_factor if t_ntk_factor is not None else self.t_ntk_factor + + h_theta = 10000.0 * h_ntk_factor + w_theta = 10000.0 * w_ntk_factor + t_theta = 10000.0 * t_ntk_factor + + h_spatial_freqs = 1.0 / (h_theta**self.dim_spatial_range.to(device=device)) + w_spatial_freqs = 1.0 / (w_theta**self.dim_spatial_range.to(device=device)) + temporal_freqs = 1.0 / (t_theta**self.dim_temporal_range.to(device=device)) + + B, T, H, W, _ = B_T_H_W_C + uniform_fps = (fps is None) or isinstance(fps, (int, float)) or (fps.min() == fps.max()) + assert ( + uniform_fps or B == 1 or T == 1 + ), "For video batch, batch size should be 1 for non-uniform fps. For image batch, T should be 1" + assert ( + H <= self.max_h and W <= self.max_w + ), f"Input dimensions (H={H}, W={W}) exceed the maximum dimensions (max_h={self.max_h}, max_w={self.max_w})" + half_emb_h = torch.outer(self.seq[:H].to(device=device), h_spatial_freqs) + half_emb_w = torch.outer(self.seq[:W].to(device=device), w_spatial_freqs) + + # apply sequence scaling in temporal dimension + if fps is None: # image case + half_emb_t = torch.outer(self.seq[:T].to(device=device), temporal_freqs) + else: + half_emb_t = torch.outer(self.seq[:T].to(device=device) / fps * self.base_fps, temporal_freqs) + + half_emb_h = torch.stack([torch.cos(half_emb_h), -torch.sin(half_emb_h), torch.sin(half_emb_h), torch.cos(half_emb_h)], dim=-1) + half_emb_w = torch.stack([torch.cos(half_emb_w), -torch.sin(half_emb_w), torch.sin(half_emb_w), torch.cos(half_emb_w)], dim=-1) + half_emb_t = torch.stack([torch.cos(half_emb_t), -torch.sin(half_emb_t), torch.sin(half_emb_t), torch.cos(half_emb_t)], dim=-1) + + em_T_H_W_D = torch.cat( + [ + repeat(half_emb_t, "t d x -> t h w d x", h=H, w=W), + repeat(half_emb_h, "h d x -> t h w d x", t=T, w=W), + repeat(half_emb_w, "w d x -> t h w d x", t=T, h=H), + ] + , dim=-2, + ) + + return rearrange(em_T_H_W_D, "t h w d (i j) -> (t h w) d i j", i=2, j=2).float() + + +class LearnablePosEmbAxis(VideoPositionEmb): + def __init__( + self, + *, # enforce keyword arguments + interpolation: str, + model_channels: int, + len_h: int, + len_w: int, + len_t: int, + device=None, + **kwargs, + ): + """ + Args: + interpolation (str): we curretly only support "crop", ideally when we need extrapolation capacity, we should adjust frequency or other more advanced methods. they are not implemented yet. + """ + del kwargs # unused + super().__init__() + self.interpolation = interpolation + assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}" + + self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device)) + self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device)) + self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device)) + + + def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor: + B, T, H, W, _ = B_T_H_W_C + if self.interpolation == "crop": + emb_h_H = self.pos_emb_h[:H].to(device=device) + emb_w_W = self.pos_emb_w[:W].to(device=device) + emb_t_T = self.pos_emb_t[:T].to(device=device) + emb = ( + repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W) + + repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W) + + repeat(emb_w_W, "w d-> b t h w d", b=B, t=T, h=H) + ) + assert list(emb.shape)[:4] == [B, T, H, W], f"bad shape: {list(emb.shape)[:4]} != {B, T, H, W}" + else: + raise ValueError(f"Unknown interpolation method {self.interpolation}") + + return normalize(emb, dim=-1, eps=1e-6) diff --git a/comfy/ldm/cosmos/vae.py b/comfy/ldm/cosmos/vae.py new file mode 100644 index 00000000000..94fcc54ce89 --- /dev/null +++ b/comfy/ldm/cosmos/vae.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The causal continuous video tokenizer with VAE or AE formulation for 3D data..""" + +import logging +import torch +from torch import nn +from enum import Enum + +from .cosmos_tokenizer.layers3d import ( + EncoderFactorized, + DecoderFactorized, + CausalConv3d, +) + + +class IdentityDistribution(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, parameters): + return parameters, (torch.tensor([0.0]), torch.tensor([0.0])) + + +class GaussianDistribution(torch.nn.Module): + def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0): + super().__init__() + self.min_logvar = min_logvar + self.max_logvar = max_logvar + + def sample(self, mean, logvar): + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + + def forward(self, parameters): + mean, logvar = torch.chunk(parameters, 2, dim=1) + logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar) + return self.sample(mean, logvar), (mean, logvar) + + +class ContinuousFormulation(Enum): + VAE = GaussianDistribution + AE = IdentityDistribution + + +class CausalContinuousVideoTokenizer(nn.Module): + def __init__( + self, z_channels: int, z_factor: int, latent_channels: int, **kwargs + ) -> None: + super().__init__() + self.name = kwargs.get("name", "CausalContinuousVideoTokenizer") + self.latent_channels = latent_channels + self.sigma_data = 0.5 + + # encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name) + self.encoder = EncoderFactorized( + z_channels=z_factor * z_channels, **kwargs + ) + if kwargs.get("temporal_compression", 4) == 4: + kwargs["channels_mult"] = [2, 4] + # decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name) + self.decoder = DecoderFactorized( + z_channels=z_channels, **kwargs + ) + + self.quant_conv = CausalConv3d( + z_factor * z_channels, + z_factor * latent_channels, + kernel_size=1, + padding=0, + ) + self.post_quant_conv = CausalConv3d( + latent_channels, z_channels, kernel_size=1, padding=0 + ) + + # formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name) + self.distribution = IdentityDistribution() # ContinuousFormulation[formulation_name].value() + + num_parameters = sum(param.numel() for param in self.parameters()) + logging.info(f"model={self.name}, num_parameters={num_parameters:,}") + logging.info( + f"z_channels={z_channels}, latent_channels={self.latent_channels}." + ) + + latent_temporal_chunk = 16 + self.latent_mean = nn.Parameter(torch.zeros([self.latent_channels * latent_temporal_chunk], dtype=torch.float32)) + self.latent_std = nn.Parameter(torch.ones([self.latent_channels * latent_temporal_chunk], dtype=torch.float32)) + + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + z, posteriors = self.distribution(moments) + latent_ch = z.shape[1] + latent_t = z.shape[2] + dtype = z.dtype + mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device) + std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device) + return ((z - mean) / std) * self.sigma_data + + def decode(self, z): + in_dtype = z.dtype + latent_ch = z.shape[1] + latent_t = z.shape[2] + mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) + std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device) + + z = z / self.sigma_data + z = z * std + mean + z = self.post_quant_conv(z) + return self.decoder(z) + diff --git a/comfy/model_base.py b/comfy/model_base.py index 141f3f4072b..a67504cbbe0 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -33,6 +33,7 @@ import comfy.ldm.flux.model import comfy.ldm.lightricks.model import comfy.ldm.hunyuan_video.model +import comfy.ldm.cosmos.model import comfy.model_management import comfy.patcher_extension @@ -856,3 +857,19 @@ def extra_conds(self, **kwargs): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)])) return out + +class CosmosVideo(BaseModel): + def __init__(self, model_config, model_type=ModelType.EDM, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None)) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index f5a33cd9b5f..20cd6bb86d4 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -239,6 +239,50 @@ def detect_unet_config(state_dict, key_prefix): dit_config["micro_condition"] = False return dit_config + if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys: + dit_config = {} + dit_config["image_model"] = "cosmos" + dit_config["max_img_h"] = 240 + dit_config["max_img_w"] = 240 + dit_config["max_frames"] = 128 + dit_config["in_channels"] = 16 + dit_config["out_channels"] = 16 + dit_config["patch_spatial"] = 2 + dit_config["patch_temporal"] = 1 + dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0] + dit_config["block_config"] = "FA-CA-MLP" + dit_config["concat_padding_mask"] = True + dit_config["pos_emb_cls"] = "rope3d" + dit_config["pos_emb_learnable"] = False + dit_config["pos_emb_interpolation"] = "crop" + dit_config["block_x_format"] = "THWBD" + dit_config["affline_emb_norm"] = True + dit_config["use_adaln_lora"] = True + dit_config["adaln_lora_dim"] = 256 + + if dit_config["model_channels"] == 4096: + # 7B + dit_config["num_blocks"] = 28 + dit_config["num_heads"] = 32 + dit_config["extra_per_block_abs_pos_emb"] = True + dit_config["rope_h_extrapolation_ratio"] = 1.0 + dit_config["rope_w_extrapolation_ratio"] = 1.0 + dit_config["rope_t_extrapolation_ratio"] = 2.0 + dit_config["extra_per_block_abs_pos_emb_type"] = "learnable" + else: # 5120 + # 14B + dit_config["num_blocks"] = 36 + dit_config["num_heads"] = 40 + dit_config["extra_per_block_abs_pos_emb"] = True + dit_config["rope_h_extrapolation_ratio"] = 2.0 + dit_config["rope_w_extrapolation_ratio"] = 2.0 + dit_config["rope_t_extrapolation_ratio"] = 2.0 + dit_config["extra_h_extrapolation_ratio"] = 2.0 + dit_config["extra_w_extrapolation_ratio"] = 2.0 + dit_config["extra_t_extrapolation_ratio"] = 2.0 + dit_config["extra_per_block_abs_pos_emb_type"] = "learnable" + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -393,6 +437,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal def unet_prefix_from_state_dict(state_dict): candidates = ["model.diffusion_model.", #ldm/sgm models "model.model.", #audio models + "net.", #cosmos ] counts = {k: 0 for k in candidates} for k in state_dict: diff --git a/comfy/sd.py b/comfy/sd.py index f3c65f9eefd..7db1c2d6097 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -11,6 +11,7 @@ from .ldm.audio.autoencoder import AudioOobleckVAE import comfy.ldm.genmo.vae.model import comfy.ldm.lightricks.vae.causal_video_autoencoder +import comfy.ldm.cosmos.vae import yaml import math @@ -34,6 +35,7 @@ import comfy.text_encoders.genmo import comfy.text_encoders.lt import comfy.text_encoders.hunyuan_video +import comfy.text_encoders.cosmos import comfy.model_patcher import comfy.lora @@ -376,6 +378,19 @@ def __init__(self, sd=None, device=None, config=None, dtype=None): self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] + elif "decoder.unpatcher3d.wavelets" in sd: + self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8) + self.upscale_index_formula = (8, 8, 8) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 8, 8) + self.downscale_index_formula = (8, 8, 8) + self.latent_dim = 3 + self.latent_channels = 16 + ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8} + self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig) + #TODO: these values are a bit off because this is not a standard VAE + self.memory_used_decode = lambda shape, dtype: (220 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (500 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + self.working_dtypes = [torch.bfloat16, torch.float32] else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -641,6 +656,7 @@ class CLIPType(Enum): LTXV = 8 HUNYUAN_VIDEO = 9 PIXART = 10 + COSMOS = 11 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -658,6 +674,7 @@ class TEModel(Enum): T5_XL = 5 T5_BASE = 6 LLAMA3_8 = 7 + T5_XXL_OLD = 8 def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -672,6 +689,8 @@ def detect_te_model(sd): return TEModel.T5_XXL elif weight.shape[-1] == 2048: return TEModel.T5_XL + if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd: + return TEModel.T5_XXL_OLD if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd: return TEModel.T5_BASE if "model.layers.0.post_attention_layernorm.weight" in sd: @@ -681,9 +700,10 @@ def detect_te_model(sd): def t5xxl_detect(clip_data): weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" + weight_name_old = "encoder.block.23.layer.1.DenseReluDense.wi.weight" for sd in clip_data: - if weight_name in sd: + if weight_name in sd or weight_name_old in sd: return comfy.text_encoders.sd3_clip.t5_xxl_detect(sd) return {} @@ -740,6 +760,9 @@ class EmptyClass: else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer + elif te_model == TEModel.T5_XXL_OLD: + clip_target.clip = comfy.text_encoders.cosmos.te(**t5xxl_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.cosmos.CosmosT5Tokenizer elif te_model == TEModel.T5_XL: clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 6a2cc75ae5b..31de1ae9e8d 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -14,6 +14,7 @@ import comfy.text_encoders.genmo import comfy.text_encoders.lt import comfy.text_encoders.hunyuan_video +import comfy.text_encoders.cosmos from . import supported_models_base from . import latent_formats @@ -823,6 +824,37 @@ def clip_target(self, state_dict={}): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect)) -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo] +class Cosmos(supported_models_base.BASE): + unet_config = { + "image_model": "cosmos", + } + + sampling_settings = { + "sigma_data": 0.5, + "sigma_max": 80.0, + "sigma_min": 0.002, + } + + unet_extra_config = {} + latent_format = latent_formats.Cosmos1CV8x8x8 + + memory_usage_factor = 2.4 #TODO + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.CosmosVideo(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect)) + + +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, Cosmos] models += [SVD_img2vid] diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py new file mode 100644 index 00000000000..5441c895256 --- /dev/null +++ b/comfy/text_encoders/cosmos.py @@ -0,0 +1,42 @@ +from comfy import sd1_clip +import comfy.text_encoders.t5 +import os +from transformers import T5TokenizerFast + + +class T5XXLModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json") + t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) + if t5xxl_scaled_fp8 is not None: + model_options = model_options.copy() + model_options["scaled_fp8"] = t5xxl_scaled_fp8 + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options) + +class CosmosT5XXL(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options) + + +class T5XXLTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512) + + +class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) + + +def te(dtype_t5=None, t5xxl_scaled_fp8=None): + class CosmosTEModel_(CosmosT5XXL): + def __init__(self, device="cpu", dtype=None, model_options={}): + if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + if dtype is None: + dtype = dtype_t5 + super().__init__(device=device, dtype=dtype, model_options=model_options) + return CosmosTEModel_ diff --git a/comfy/text_encoders/t5_old_config_xxl.json b/comfy/text_encoders/t5_old_config_xxl.json new file mode 100644 index 00000000000..c9fdd778219 --- /dev/null +++ b/comfy/text_encoders/t5_old_config_xxl.json @@ -0,0 +1,22 @@ +{ + "d_ff": 65536, + "d_kv": 128, + "d_model": 1024, + "decoder_start_token_id": 0, + "dropout_rate": 0.1, + "eos_token_id": 1, + "dense_act_fn": "relu", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": false, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 24, + "num_heads": 128, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "vocab_size": 32128 +} diff --git a/comfy_extras/nodes_cosmos.py b/comfy_extras/nodes_cosmos.py new file mode 100644 index 00000000000..d88773e259c --- /dev/null +++ b/comfy_extras/nodes_cosmos.py @@ -0,0 +1,23 @@ +import nodes +import torch +import comfy.model_management + +class EmptyCosmosLatentVideo: + @classmethod + def INPUT_TYPES(s): + return {"required": { "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), + "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "generate" + + CATEGORY = "latent/video" + + def generate(self, width, height, length, batch_size=1): + latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + return ({"samples":latent}, ) + +NODE_CLASS_MAPPINGS = { + "EmptyCosmosLatentVideo": EmptyCosmosLatentVideo, +} diff --git a/nodes.py b/nodes.py index 1b796ced7be..cfd7dd8a453 100644 --- a/nodes.py +++ b/nodes.py @@ -912,7 +912,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -922,7 +922,7 @@ def INPUT_TYPES(s): CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5" + DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 / clip-g / clip-l\nstable_audio: t5\nmochi: t5\ncosmos: old t5 xxl" def load_clip(self, clip_name, type="stable_diffusion", device="default"): if type == "stable_cascade": @@ -2225,6 +2225,7 @@ def init_builtin_extra_nodes(): "nodes_lt.py", "nodes_hooks.py", "nodes_load_3d.py", + "nodes_cosmos.py", ] import_failed = []