diff --git a/mindone/diffusers/README.md b/mindone/diffusers/README.md index dcb86ea3d9..4b63f10a6f 100644 --- a/mindone/diffusers/README.md +++ b/mindone/diffusers/README.md @@ -97,20 +97,7 @@ Most base, utility and mixin class are available. - [ ] StableDiffusionPipeline ### Model - -#### AutoEncoders - -- [x] AutoencoderKL - -#### UNets - -- [x] UNet1DModel -- [x] UNet2DConditionModel -- [x] StableCascadeUNet - -#### Transformers - -- [x] Transformer2DModel +- All Supported ### Scheduler - [x] DDIMScheduler/DDPMScheduler/...(30) @@ -134,6 +121,7 @@ Most base, utility and mixin class are available. Unlike the output `posterior = DiagonalGaussianDistribution(latent)`, which can do sampling by `posterior.sample()`. We can only output the `latent` and then do sampling through `AutoencoderKL.diag_gauss_dist.sample(latent)`. + ## Credits Hacked together @geniuspatrick. diff --git a/mindone/diffusers/__init__.py b/mindone/diffusers/__init__.py index d3024f5efd..0035a62b96 100644 --- a/mindone/diffusers/__init__.py +++ b/mindone/diffusers/__init__.py @@ -14,13 +14,31 @@ _import_structure = { "configuration_utils": ["ConfigMixin"], "models": [ + "AsymmetricAutoencoderKL", "AutoencoderKL", + "AutoencoderKLTemporalDecoder", + "AutoencoderTiny", + "ConsistencyDecoderVAE", + "ControlNetModel", + "I2VGenXLUNet", + "Kandinsky3UNet", "ModelMixin", + "MotionAdapter", + "MultiAdapter", + "PriorTransformer", + "T2IAdapter", + "T5FilmDecoder", + "Transformer2DModel", "SD3Transformer2DModel", + "StableCascadeUNet", "UNet1DModel", "UNet2DConditionModel", "UNet2DModel", - "StableCascadeUNet", + "UNet3DConditionModel", + "UNetMotionModel", + "UNetSpatioTemporalConditionModel", + "UVit2DModel", + "VQModel", ], "optimization": [ "get_constant_schedule", @@ -80,13 +98,31 @@ if TYPE_CHECKING: from .configuration_utils import ConfigMixin from .models import ( + AsymmetricAutoencoderKL, AutoencoderKL, + AutoencoderKLTemporalDecoder, + AutoencoderTiny, + ConsistencyDecoderVAE, + ControlNetModel, + I2VGenXLUNet, + Kandinsky3UNet, ModelMixin, + MotionAdapter, + MultiAdapter, + PriorTransformer, SD3Transformer2DModel, StableCascadeUNet, + T2IAdapter, + T5FilmDecoder, + Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, + UNet3DConditionModel, + UNetMotionModel, + UNetSpatioTemporalConditionModel, + UVit2DModel, + VQModel, ) from .optimization import ( get_constant_schedule, diff --git a/mindone/diffusers/models/__init__.py b/mindone/diffusers/models/__init__.py index 9440cbb1af..84f00e2df4 100644 --- a/mindone/diffusers/models/__init__.py +++ b/mindone/diffusers/models/__init__.py @@ -18,26 +18,67 @@ _import_structure = { "adapter": ["MultiAdapter", "T2IAdapter"], + "autoencoders.autoencoder_asym_kl": ["AsymmetricAutoencoderKL"], "autoencoders.autoencoder_kl": ["AutoencoderKL"], + "autoencoders.autoencoder_kl_temporal_decoder": ["AutoencoderKLTemporalDecoder"], + "autoencoders.autoencoder_tiny": ["AutoencoderTiny"], + "autoencoders.consistency_decoder_vae": ["ConsistencyDecoderVAE"], "controlnet": ["ControlNetModel"], + "dual_transformer_2d": ["DualTransformer2DModel"], "embeddings": ["ImageProjection"], "modeling_utils": ["ModelMixin"], + "transformers.prior_transformer": ["PriorTransformer"], + "transformers.t5_film_transformer": ["T5FilmDecoder"], "transformers.transformer_2d": ["Transformer2DModel"], + "transformers.transformer_temporal": ["TransformerTemporalModel"], "transformers.transformer_sd3": ["SD3Transformer2DModel"], "unets.unet_1d": ["UNet1DModel"], "unets.unet_2d": ["UNet2DModel"], 
"unets.unet_2d_condition": ["UNet2DConditionModel"], + "unets.unet_3d_condition": ["UNet3DConditionModel"], + "unets.unet_i2vgen_xl": ["I2VGenXLUNet"], + "unets.unet_kandinsky3": ["Kandinsky3UNet"], + "unets.unet_motion_model": ["MotionAdapter", "UNetMotionModel"], "unets.unet_stable_cascade": ["StableCascadeUNet"], + "unets.unet_spatio_temporal_condition": ["UNetSpatioTemporalConditionModel"], + "unets.uvit_2d": ["UVit2DModel"], + "vq_model": ["VQModel"], } if TYPE_CHECKING: from .adapter import MultiAdapter, T2IAdapter - from .autoencoders import AutoencoderKL + from .autoencoders import ( + AsymmetricAutoencoderKL, + AutoencoderKL, + AutoencoderKLTemporalDecoder, + AutoencoderTiny, + ConsistencyDecoderVAE, + ) from .controlnet import ControlNetModel from .embeddings import ImageProjection from .modeling_utils import ModelMixin - from .transformers import SD3Transformer2DModel, Transformer2DModel - from .unets import StableCascadeUNet, UNet1DModel, UNet2DConditionModel, UNet2DModel + from .transformers import ( + DualTransformer2DModel, + PriorTransformer, + SD3Transformer2DModel, + T5FilmDecoder, + Transformer2DModel, + TransformerTemporalModel, + ) + from .unets import ( + I2VGenXLUNet, + Kandinsky3UNet, + MotionAdapter, + StableCascadeUNet, + UNet1DModel, + UNet2DConditionModel, + UNet2DModel, + UNet3DConditionModel, + UNetMotionModel, + UNetSpatioTemporalConditionModel, + UVit2DModel, + ) + from .vq_model import VQModel else: import sys diff --git a/mindone/diffusers/models/attention_processor.py b/mindone/diffusers/models/attention_processor.py index 7b798838a7..0a242f736b 100644 --- a/mindone/diffusers/models/attention_processor.py +++ b/mindone/diffusers/models/attention_processor.py @@ -430,7 +430,7 @@ def get_attention_scores(self, query: ms.Tensor, key: ms.Tensor, attention_mask: ) else: attention_scores = ops.baddbmm( - attention_mask, + attention_mask.to(query.dtype), query, key.swapaxes(-1, -2), beta=1, @@ -475,7 +475,9 @@ def prepare_attention_mask( # we want to instead pad by (0, remaining_length), where remaining_length is: # remaining_length: int = target_length - current_length # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding - attention_mask = ops.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = ops.Pad(paddings=((0, 0),) * (attention_mask.ndim - 1) + ((0, target_length),))( + attention_mask + ) if out_dim == 3: if attention_mask.shape[0] < batch_size * head_size: diff --git a/mindone/diffusers/models/autoencoders/__init__.py b/mindone/diffusers/models/autoencoders/__init__.py index eea9710eed..201a40ff17 100644 --- a/mindone/diffusers/models/autoencoders/__init__.py +++ b/mindone/diffusers/models/autoencoders/__init__.py @@ -1 +1,5 @@ +from .autoencoder_asym_kl import AsymmetricAutoencoderKL from .autoencoder_kl import AutoencoderKL +from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder +from .autoencoder_tiny import AutoencoderTiny +from .consistency_decoder_vae import ConsistencyDecoderVAE diff --git a/mindone/diffusers/models/autoencoders/autoencoder_asym_kl.py b/mindone/diffusers/models/autoencoders/autoencoder_asym_kl.py new file mode 100644 index 0000000000..ae2e496c45 --- /dev/null +++ b/mindone/diffusers/models/autoencoders/autoencoder_asym_kl.py @@ -0,0 +1,183 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Tuple, Union + +import numpy as np + +import mindspore as ms +from mindspore import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder + + +class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): + r""" + Designing a Better Asymmetric VQGAN for StableDiffusion https://arxiv.org/abs/2306.04632 . A VAE model with KL loss + for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of down block output channels. + layers_per_down_block (`int`, *optional*, defaults to `1`): + Number layers for down block. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of up block output channels. + layers_per_up_block (`int`, *optional*, defaults to `1`): + Number layers for up block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + norm_num_groups (`int`, *optional*, defaults to `32`): + Number of groups to use for the first normalization layer in ResNet blocks. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), + down_block_out_channels: Tuple[int, ...] = (64,), + layers_per_down_block: int = 1, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + up_block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_up_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + scaling_factor: float = 0.18215, + ) -> None: + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=down_block_out_channels, + layers_per_block=layers_per_down_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = MaskConditionDecoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=up_block_out_channels, + layers_per_block=layers_per_up_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + ) + self.diag_gauss_dist = DiagonalGaussianDistribution() + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1, has_bias=True) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1, has_bias=True) + + self.use_slicing = False + self.use_tiling = False + + self.register_to_config(block_out_channels=up_block_out_channels) + self.register_to_config(force_upcast=False) + + def encode(self, x: ms.Tensor, return_dict: bool = False) -> Union[AutoencoderKLOutput, Tuple[ms.Tensor]]: + h = self.encoder(x) + moments = self.quant_conv(h) + + if not return_dict: + return (moments,) + + return AutoencoderKLOutput(latent=moments) + + def _decode( + self, + z: ms.Tensor, + image: Optional[ms.Tensor] = None, + mask: Optional[ms.Tensor] = None, + return_dict: bool = False, + ) -> Union[DecoderOutput, Tuple[ms.Tensor]]: + z = self.post_quant_conv(z) + dec = self.decoder(z, image, mask) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def decode( + self, + z: ms.Tensor, + generator: Optional[np.random.Generator] = None, + image: Optional[ms.Tensor] = None, + mask: Optional[ms.Tensor] = None, + return_dict: bool = False, + ) -> Union[DecoderOutput, Tuple[ms.Tensor]]: + decoded = self._decode(z, image, mask)[0] + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def construct( + self, + sample: ms.Tensor, + mask: Optional[ms.Tensor] = None, + sample_posterior: bool = False, + return_dict: bool = False, + ) -> Union[DecoderOutput, Tuple[ms.Tensor]]: + r""" + Args: + sample (`ms.Tensor`): Input sample. + mask (`ms.Tensor`, *optional*, defaults to `None`): Optional inpainting mask. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
+ """ + x = sample + latent = self.encode(x)[0] + if sample_posterior: + z = self.diag_gauss_dist.sample(latent) + else: + z = self.diag_gauss_dist.mode(latent) + + dec = self.decode(z, sample, mask)[0] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl.py b/mindone/diffusers/models/autoencoders/autoencoder_kl.py index 27fba56979..3dd38cb224 100644 --- a/mindone/diffusers/models/autoencoders/autoencoder_kl.py +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl.py @@ -144,12 +144,12 @@ def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() - for sub_name, child in module.name_cells(): + for sub_name, child in module.name_cells().items(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors - for name, module in self.name_cells(): + for name, module in self.name_cells().items(): fn_recursive_add_processors(name, module, processors) return processors @@ -183,10 +183,10 @@ def fn_recursive_attn_processor(name: str, module: nn.Cell, processor): else: module.set_processor(processor.pop(f"{name}.processor")) - for sub_name, child in module.name_cells(): + for sub_name, child in module.name_cells().items(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - for name, module in self.name_cells(): + for name, module in self.name_cells().items(): fn_recursive_attn_processor(name, module, processor) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor diff --git a/mindone/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/mindone/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py new file mode 100644 index 0000000000..936d3e169c --- /dev/null +++ b/mindone/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py @@ -0,0 +1,356 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Dict, Tuple, Union + +import mindspore as ms +from mindspore import nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ..activations import SiLU +from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..normalization import GroupNorm +from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder +from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder + + +class TemporalDecoder(nn.Cell): + def __init__( + self, + in_channels: int = 4, + out_channels: int = 3, + block_out_channels: Tuple[int] = (128, 256, 512, 512), + layers_per_block: int = 2, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[-1], kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True + ) + self.mid_block = MidBlockTemporalDecoder( + num_layers=self.layers_per_block, + in_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + attention_head_dim=block_out_channels[-1], + ) + + # up + self.up_blocks = [] + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + up_block = UpBlockTemporalDecoder( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + self.up_blocks = nn.CellList(self.up_blocks) + + self.conv_norm_out = GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-6) + + self.conv_act = SiLU() + self.conv_out = nn.Conv2d( + in_channels=block_out_channels[0], + out_channels=out_channels, + kernel_size=3, + pad_mode="pad", + padding=1, + has_bias=True, + ) + + conv_out_kernel_size = (3, 1, 1) + # padding = [int(k // 2) for k in conv_out_kernel_size] + padding = (1, 1, 0, 0, 0, 0) + self.time_conv_out = nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=conv_out_kernel_size, + pad_mode="pad", + padding=padding, + has_bias=True, + ) + + self.gradient_checkpointing = False + + def construct( + self, + sample: ms.Tensor, + image_only_indicator: ms.Tensor, + num_frames: int = 1, + ) -> ms.Tensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample) + + # middle + sample = self.mid_block(sample, image_only_indicator=image_only_indicator) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, image_only_indicator=image_only_indicator) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + batch_frames, channels, height, width = sample.shape + batch_size = batch_frames // num_frames + sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4) + sample = self.time_conv_out(sample) + + sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width) + + return sample + + +class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. 
+ + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + layers_per_block: (`int`, *optional*, defaults to 1): Number of layers per block. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + latent_channels: int = 4, + sample_size: int = 32, + scaling_factor: float = 0.18215, + force_upcast: float = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = TemporalDecoder( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1, has_bias=True) + self.diag_gauss_dist = DiagonalGaussianDistribution() + + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Encoder, TemporalDecoder)): + module.gradient_checkpointing = value + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: # type: ignore + r""" + Returns: + `dict` of attention 
processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str, AttentionProcessor]): # type: ignore + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.name_cells().items(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.name_cells().items(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): # type: ignore + r""" + Sets the attention processor to use to compute attention. + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: nn.Cell, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.name_cells().items(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.name_cells().items(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def encode( + self, x: ms.Tensor, return_dict: bool = False + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`ms.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + h = self.encoder(x) + moments = self.quant_conv(h) + + if not return_dict: + return (moments,) + + return AutoencoderKLOutput(latent=moments) + + def decode( + self, + z: ms.Tensor, + num_frames: int, + return_dict: bool = False, + ) -> Union[DecoderOutput, ms.Tensor]: + """ + Decode a batch of images. 
+ + Args: + z (`ms.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `False`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + batch_size = z.shape[0] // num_frames + image_only_indicator = ops.zeros((batch_size, num_frames), dtype=z.dtype) + decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def construct( + self, + sample: ms.Tensor, + sample_posterior: bool = False, + return_dict: bool = False, + num_frames: int = 1, + ) -> Union[DecoderOutput, ms.Tensor]: + r""" + Args: + sample (`ms.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + latent = self.encode(x)[0] + if sample_posterior: + z = self.diag_gauss_dist.sample(latent) + else: + z = self.diag_gauss_dist.mode(latent) + + dec = self.decode(z, num_frames=num_frames)[0] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/mindone/diffusers/models/autoencoders/autoencoder_tiny.py b/mindone/diffusers/models/autoencoders/autoencoder_tiny.py new file mode 100644 index 0000000000..09edf20256 --- /dev/null +++ b/mindone/diffusers/models/autoencoders/autoencoder_tiny.py @@ -0,0 +1,209 @@ +# Copyright 2024 Ollin Boer Bohan and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Tuple, Union + +import mindspore as ms +from mindspore import ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..modeling_utils import ModelMixin +from .vae import DecoderOutput, DecoderTiny, EncoderTiny + + +@dataclass +class AutoencoderTinyOutput(BaseOutput): + """ + Output of AutoencoderTiny encoding method. + + Args: + latents (`ms.Tensor`): Encoded outputs of the `Encoder`. + + """ + + latents: ms.Tensor + + +class AutoencoderTiny(ModelMixin, ConfigMixin): + r""" + A tiny distilled VAE model for encoding images into latents and decoding latent representations into images. + + [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for + all models (such as downloading or saving). + + Parameters: + in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. + out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. 
+ encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`): + Tuple of integers representing the number of output channels for each encoder block. The length of the + tuple should be equal to the number of encoder blocks. + decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`): + Tuple of integers representing the number of output channels for each decoder block. The length of the + tuple should be equal to the number of decoder blocks. + act_fn (`str`, *optional*, defaults to `"relu"`): + Activation function to be used throughout the model. + latent_channels (`int`, *optional*, defaults to 4): + Number of channels in the latent representation. The latent space acts as a compressed representation of + the input image. + upsampling_scaling_factor (`int`, *optional*, defaults to 2): + Scaling factor for upsampling in the decoder. It determines the size of the output image during the + upsampling process. + num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`): + Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The + length of the tuple should be equal to the number of stages in the encoder. Each stage has a different + number of encoder blocks. + num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`): + Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The + length of the tuple should be equal to the number of stages in the decoder. Each stage has a different + number of decoder blocks. + latent_magnitude (`float`, *optional*, defaults to 3.0): + Magnitude of the latent representation. This parameter scales the latent representation values to control + the extent of information preservation. + latent_shift (float, *optional*, defaults to 0.5): + Shift applied to the latent representation. This parameter controls the center of the latent space. + scaling_factor (`float`, *optional*, defaults to 1.0): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder, + however, no such scaling factor was used, hence the value of 1.0 as the default. + force_upcast (`bool`, *optional*, default to `False`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without losing too much precision, in which case + `force_upcast` can be set to `False` (see this fp16-friendly + [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64), + decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64), + act_fn: str = "relu", + latent_channels: int = 4, + upsampling_scaling_factor: int = 2, + num_encoder_blocks: Tuple[int, ...] 
= (1, 3, 3, 3), + num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1), + latent_magnitude: int = 3, + latent_shift: float = 0.5, + force_upcast: bool = False, + scaling_factor: float = 1.0, + ): + super().__init__() + + if len(encoder_block_out_channels) != len(num_encoder_blocks): + raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.") + if len(decoder_block_out_channels) != len(num_decoder_blocks): + raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.") + + self.encoder = EncoderTiny( + in_channels=in_channels, + out_channels=latent_channels, + num_blocks=num_encoder_blocks, + block_out_channels=encoder_block_out_channels, + act_fn=act_fn, + ) + + self.decoder = DecoderTiny( + in_channels=latent_channels, + out_channels=out_channels, + num_blocks=num_decoder_blocks, + block_out_channels=decoder_block_out_channels, + upsampling_scaling_factor=upsampling_scaling_factor, + act_fn=act_fn, + ) + + self.latent_magnitude = latent_magnitude + self.latent_shift = latent_shift + self.scaling_factor = scaling_factor + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.spatial_scale_factor = 2**out_channels + self.tile_overlap_factor = 0.125 + self.tile_sample_min_size = 512 + self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor + + self.register_to_config(block_out_channels=decoder_block_out_channels) + self.register_to_config(force_upcast=False) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (EncoderTiny, DecoderTiny)): + module.gradient_checkpointing = value + + def scale_latents(self, x: ms.Tensor) -> ms.Tensor: + """raw latents -> [0, 1]""" + return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1) + + def unscale_latents(self, x: ms.Tensor) -> ms.Tensor: + """[0, 1] -> raw latents""" + return x.sub(self.latent_shift).mul(2 * self.latent_magnitude) + + def encode(self, x: ms.Tensor, return_dict: bool = False) -> Union[AutoencoderTinyOutput, Tuple[ms.Tensor]]: + output = self.encoder(x) + + if not return_dict: + return (output,) + + return AutoencoderTinyOutput(latents=output) + + def decode(self, x: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, Tuple[ms.Tensor]]: + output = self.decoder(x) + + if not return_dict: + return (output,) + + return DecoderOutput(sample=output) + + def construct( + self, + sample: ms.Tensor, + return_dict: bool = False, + ) -> Union[DecoderOutput, Tuple[ms.Tensor]]: + r""" + Args: + sample (`ms.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + hidden_dtype = sample.dtype + enc = self.encode(sample)[0] + + # scale latents to be in [0, 1], then quantize latents to a byte tensor, + # as if we were storing the latents in an RGBA uint8 image. + scaled_enc = ops.round(self.scale_latents(enc).mul(255)).to(ms.uint8) + + # unquantize latents back into [0, 1], then unscale latents back to their original range, + # as if we were loading the latents from an RGBA uint8 image. 
+ unscaled_enc = self.unscale_latents(scaled_enc / 255.0).to(hidden_dtype) + + # Keep an eye on it: it's different from diffusers: ...[0] + dec = self.decode(unscaled_enc)[0] + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) diff --git a/mindone/diffusers/models/autoencoders/consistency_decoder_vae.py b/mindone/diffusers/models/autoencoders/consistency_decoder_vae.py new file mode 100644 index 0000000000..a5f0ee88d5 --- /dev/null +++ b/mindone/diffusers/models/autoencoders/consistency_decoder_vae.py @@ -0,0 +1,309 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import numpy as np + +import mindspore as ms +from mindspore import nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ...schedulers import ConsistencyDecoderScheduler +from ...utils import BaseOutput +from ...utils.mindspore_utils import randn_tensor +from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from ..modeling_utils import ModelMixin +from ..unets.unet_2d import UNet2DModel +from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder + + +@dataclass +class ConsistencyDecoderVAEOutput(BaseOutput): + """ + Output of encoding method. + + Args: + latent_dist (`DiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. + `DiagonalGaussianDistribution` allows for sampling latents from the distribution. + """ + + latent: ms.Tensor + + +class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): + r""" + The consistency decoder used with DALL-E 3. + + Examples: + ```py + >>> import mindspore + >>> from mindone.diffusers import StableDiffusionPipeline, ConsistencyDecoderVAE + + >>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", mindspore_dtype=mindspore.float16) + >>> pipe = StableDiffusionPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", vae=vae, mindspore_dtype=mindspore.float16 + ... ) + + >>> pipe("horse").images + ``` + """ + + @register_to_config + def __init__( + self, + scaling_factor: float = 0.18215, + latent_channels: int = 4, + encoder_act_fn: str = "silu", + encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + encoder_double_z: bool = True, + encoder_down_block_types: Tuple[str, ...] = ( + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ), + encoder_in_channels: int = 3, + encoder_layers_per_block: int = 2, + encoder_norm_num_groups: int = 32, + encoder_out_channels: int = 4, + decoder_add_attention: bool = False, + decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024), + decoder_down_block_types: Tuple[str, ...] 
= ( + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + ), + decoder_downsample_padding: int = 1, + decoder_in_channels: int = 7, + decoder_layers_per_block: int = 3, + decoder_norm_eps: float = 1e-05, + decoder_norm_num_groups: int = 32, + decoder_num_train_timesteps: int = 1024, + decoder_out_channels: int = 6, + decoder_resnet_time_scale_shift: str = "scale_shift", + decoder_time_embedding_type: str = "learned", + decoder_up_block_types: Tuple[str, ...] = ( + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ), + ): + super().__init__() + self.encoder = Encoder( + act_fn=encoder_act_fn, + block_out_channels=encoder_block_out_channels, + double_z=encoder_double_z, + down_block_types=encoder_down_block_types, + in_channels=encoder_in_channels, + layers_per_block=encoder_layers_per_block, + norm_num_groups=encoder_norm_num_groups, + out_channels=encoder_out_channels, + ) + + self.decoder_unet = UNet2DModel( + add_attention=decoder_add_attention, + block_out_channels=decoder_block_out_channels, + down_block_types=decoder_down_block_types, + downsample_padding=decoder_downsample_padding, + in_channels=decoder_in_channels, + layers_per_block=decoder_layers_per_block, + norm_eps=decoder_norm_eps, + norm_num_groups=decoder_norm_num_groups, + num_train_timesteps=decoder_num_train_timesteps, + out_channels=decoder_out_channels, + resnet_time_scale_shift=decoder_resnet_time_scale_shift, + time_embedding_type=decoder_time_embedding_type, + up_block_types=decoder_up_block_types, + ) + self.diag_gauss_dist = DiagonalGaussianDistribution() + self.decoder_scheduler = ConsistencyDecoderScheduler() + self.register_to_config(block_out_channels=encoder_block_out_channels) + self.register_to_config(force_upcast=False) + self.means = ms.Tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None] + self.stds = ms.Tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None] + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1, has_bias=True) + + self.use_slicing = False + self.use_tiling = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: # type: ignore + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str, AttentionProcessor]): # type: ignore + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.name_cells().items(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.name_cells().items(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): # type: ignore + r""" + Sets the attention processor to use to compute attention. 
+ Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: nn.Cell, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.name_cells().items(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.name_cells().items(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def encode( + self, x: ms.Tensor, return_dict: bool = False + ) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`ms.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.consistecy_decoder_vae.ConsistencyDecoderOoutput`] instead of a plain + tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a plain `tuple` + is returned. 
+ """ + h = self.encoder(x) + + moments = self.quant_conv(h) + + if not return_dict: + return (moments,) + + return ConsistencyDecoderVAEOutput(latent=moments) + + def decode( + self, + z: ms.Tensor, + generator: Optional[np.random.Generator] = None, + return_dict: bool = False, + num_inference_steps: int = 2, + ) -> Union[DecoderOutput, Tuple[ms.Tensor]]: + z = ((z * self.config["scaling_factor"] - self.means) / self.stds).to(z.dtype) + + scale_factor = 2 ** (len(self.config["block_out_channels"]) - 1) + z = ops.interpolate(z, mode="nearest", size=(z.shape[-2] * scale_factor, z.shape[-1] * scale_factor)) + + batch_size, _, height, width = z.shape + + # self.decoder_scheduler.set_timesteps(num_inference_steps) + + x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor( + (batch_size, 3, height, width), + generator=generator, + dtype=z.dtype, + ) + + for t in self.decoder_scheduler.timesteps: + model_input = ops.concat([self.decoder_scheduler.scale_model_input(x_t, t).to(z.dtype), z], axis=1) + model_output = self.decoder_unet(model_input, t)[0][:, :3, :, :] + prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator)[0] + x_t = prev_sample + + x_0 = x_t + + if not return_dict: + return (x_0,) + + return DecoderOutput(sample=x_0) + + def construct( + self, + sample: ms.Tensor, + sample_posterior: bool = False, + return_dict: bool = False, + generator: Optional[np.random.Generator] = None, + ) -> Union[DecoderOutput, Tuple[ms.Tensor]]: + r""" + Args: + sample (`ms.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + generator (`np.random.Generator`, *optional*, defaults to `None`): + Generator to use for sampling. + + Returns: + [`DecoderOutput`] or `tuple`: + If return_dict is True, a [`DecoderOutput`] is returned, otherwise a plain `tuple` is returned. + """ + x = sample + latent = self.encode(x)[0] + if sample_posterior: + z = self.diag_gauss_dist.sample(latent) + else: + z = self.diag_gauss_dist.mode(latent) + dec = self.decode(z, generator=generator)[0] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/mindone/diffusers/models/autoencoders/vae.py b/mindone/diffusers/models/autoencoders/vae.py index 45f5f41687..2161fb7b3c 100644 --- a/mindone/diffusers/models/autoencoders/vae.py +++ b/mindone/diffusers/models/autoencoders/vae.py @@ -20,10 +20,10 @@ from mindspore import nn, ops from ...utils import BaseOutput -from ..activations import SiLU +from ..activations import SiLU, get_activation from ..attention_processor import SpatialNorm from ..normalization import GroupNorm -from ..unets.unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block +from ..unets.unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block @dataclass @@ -306,6 +306,373 @@ def construct( return sample +class UpSample(nn.Cell): + r""" + The `UpSample` layer of a variational autoencoder that upsamples its input. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.deconv = nn.Conv2dTranspose( + in_channels, out_channels, kernel_size=4, stride=2, pad_mode="pad", padding=1, has_bias=True + ) + + def construct(self, x: ms.Tensor) -> ms.Tensor: + r"""The forward method of the `UpSample` class.""" + x = ops.relu(x) + x = self.deconv(x) + return x + + +class MaskConditionEncoder(nn.Cell): + """ + used in AsymmetricAutoencoderKL + """ + + def __init__( + self, + in_ch: int, + out_ch: int = 192, + res_ch: int = 768, + stride: int = 16, + ) -> None: + super().__init__() + + channels = [] + while stride > 1: + stride = stride // 2 + in_ch_ = out_ch * 2 + if out_ch > res_ch: + out_ch = res_ch + if stride == 1: + in_ch_ = res_ch + channels.append((in_ch_, out_ch)) + out_ch *= 2 + + out_channels = [] + for _in_ch, _out_ch in channels: + out_channels.append(_out_ch) + out_channels.append(channels[-1][0]) + + layers = [] + in_ch_ = in_ch + for l in range(len(out_channels)): # noqa: E741 + out_ch_ = out_channels[l] + if l == 0 or l == 1: # noqa: E741 + layers.append( + nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, pad_mode="pad", padding=1, has_bias=True) + ) + else: + layers.append( + nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, pad_mode="pad", padding=1, has_bias=True) + ) + in_ch_ = out_ch_ + + # nn.SequentialCell does not support the len(self.layers) method nor does it support self.layers[l], + # therefore, we use nn.CellList instead. + self.layers = nn.CellList(layers) + + def construct(self, x: ms.Tensor, mask=None) -> ms.Tensor: + r"""The forward method of the `MaskConditionEncoder` class.""" + out = {} + for l in range(len(self.layers)): # noqa: E741 + layer = self.layers[l] + x = layer(x) + out[str(tuple(x.shape))] = x + x = ops.relu(x) + return out + + +class MaskConditionDecoder(nn.Cell): + r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's + decoder with a conditioner on the mask and masked image. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + pad_mode="pad", + padding=1, + has_bias=True, + ) + + self.mid_block = None + self.up_blocks = [] + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + self.up_blocks = nn.CellList(self.up_blocks) + + # condition encoder + self.condition_encoder = MaskConditionEncoder( + in_ch=out_channels, + out_ch=block_out_channels[0], + res_ch=block_out_channels[-1], + ) + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, pad_mode="pad", padding=1, has_bias=True) + + self.gradient_checkpointing = False + + def construct( + self, + z: ms.Tensor, + image: Optional[ms.Tensor] = None, + mask: Optional[ms.Tensor] = None, + latent_embeds: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + r"""The forward method of the `MaskConditionDecoder` class.""" + sample = z + sample = self.conv_in(sample) + + # middle + sample = self.mid_block(sample, latent_embeds) + + # condition encoder + im_x = None + if image is not None and mask is not None: + masked_image = (1 - mask) * image + im_x = self.condition_encoder(masked_image, mask) + + # up + for up_block in self.up_blocks: + if image is not None and mask is not None: + sample_ = im_x[str(tuple(sample.shape))] + mask_ = ops.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = up_block(sample, latent_embeds) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class VectorQuantizer(nn.Cell): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. 
Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__( + self, + n_e: int, + vq_embed_dim: int, + beta: float, + remap=None, + unknown_index: str = "random", + sane_index_shape: bool = False, + legacy: bool = True, + ): + super().__init__() + self.n_e = n_e + self.vq_embed_dim = vq_embed_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding( + self.n_e, + self.vq_embed_dim, + embedding_table=ops.uniform( + (self.n_e, self.vq_embed_dim), ms.Tensor(-1.0 / self.n_e), ms.Tensor(1.0 / self.n_e) + ), + ) + + self.remap = remap + if self.remap is not None: + self.used = ms.Tensor(np.load(self.remap)) + self.used: ms.Tensor + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds: ms.Tensor) -> ms.Tensor: + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds.dtype) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = ops.randint(0, self.re_embed, size=new[unknown].shape) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds: ms.Tensor) -> ms.Tensor: + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds.dtype) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = ops.gather_elements(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def construct(self, z: ms.Tensor) -> Tuple[ms.Tensor, ms.Tensor, Tuple]: + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1) + z_flattened = z.view(-1, self.vq_embed_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + # ops.cdist caused unexpected error on NPU, use equivalent implement + cdist = ops.pow(z_flattened[:, None, :] - self.embedding.embedding_table[None, ...], 2).mean(axis=-1) + min_encoding_indices = ops.argmin(cdist, axis=1) + + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * ops.mean((ops.stop_gradient(z_q) - z) ** 2) + ops.mean((z_q - ops.stop_gradient(z)) ** 2) + else: + loss = ops.mean((ops.stop_gradient(z_q) - z) ** 2) + self.beta * ops.mean((z_q - ops.stop_gradient(z)) ** 2) + + # preserve gradients + z_q = z + ops.stop_gradient(z_q - z) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2) + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = 
min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices: ms.Tensor, shape: Tuple[int, ...]) -> ms.Tensor: + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2) + + return z_q + + @ms.jit_class class DiagonalGaussianDistribution(object): def __init__(self, deterministic: bool = False): @@ -367,3 +734,166 @@ def nll(self, parameters: ms.Tensor, sample: ms.Tensor, dims: Tuple[int, ...] = def mode(self, parameters: ms.Tensor) -> ms.Tensor: mean, logvar, var, std = self.init(parameters) return mean + + +class EncoderTiny(nn.Cell): + r""" + The `EncoderTiny` layer is a simpler version of the `Encoder` layer. + + Args: + in_channels (`int`): + The number of input channels. + out_channels (`int`): + The number of output channels. + num_blocks (`Tuple[int, ...]`): + Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to + use. + block_out_channels (`Tuple[int, ...]`): + The number of output channels for each block. + act_fn (`str`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: Tuple[int, ...], + block_out_channels: Tuple[int, ...], + act_fn: str, + ): + super().__init__() + + layers = [] + for i, num_block in enumerate(num_blocks): + num_channels = block_out_channels[i] + + if i == 0: + layers.append( + nn.Conv2d(in_channels, num_channels, kernel_size=3, pad_mode="pad", padding=1, has_bias=True) + ) + else: + layers.append( + nn.Conv2d( + num_channels, + num_channels, + kernel_size=3, + pad_mode="pad", + padding=1, + stride=2, + has_bias=False, + ) + ) + + for _ in range(num_block): + layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) + + layers.append( + nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, pad_mode="pad", padding=1, has_bias=True) + ) + + self.layers = nn.SequentialCell(*layers) + self.gradient_checkpointing = False + + def construct(self, x: ms.Tensor) -> ms.Tensor: + r"""The forward method of the `EncoderTiny` class.""" + # scale image from [-1, 1] to [0, 1] to match TAESD convention + x = self.layers(x.add(1).div(2)) + + return x + + +class DecoderTiny(nn.Cell): + r""" + The `DecoderTiny` layer is a simpler version of the `Decoder` layer. + + Args: + in_channels (`int`): + The number of input channels. + out_channels (`int`): + The number of output channels. + num_blocks (`Tuple[int, ...]`): + Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to + use. + block_out_channels (`Tuple[int, ...]`): + The number of output channels for each block. + upsampling_scaling_factor (`int`): + The scaling factor to use for upsampling. + act_fn (`str`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: Tuple[int, ...], + block_out_channels: Tuple[int, ...], + upsampling_scaling_factor: int, + act_fn: str, + ): + super().__init__() + + layers = [ + nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, pad_mode="pad", padding=1, has_bias=True), + get_activation(act_fn)(), + ] + + for i, num_block in enumerate(num_blocks): + is_final_block = i == (len(num_blocks) - 1) + num_channels = block_out_channels[i] + + for _ in range(num_block): + layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) + + if not is_final_block: + layers.append(DecoderTinyUpsample(scale_factor=upsampling_scaling_factor)) + + conv_out_channel = num_channels if not is_final_block else out_channels + layers.append( + nn.Conv2d( + num_channels, + conv_out_channel, + kernel_size=3, + pad_mode="pad", + padding=1, + has_bias=is_final_block, + ) + ) + + self.layers = nn.SequentialCell(*layers) + self.gradient_checkpointing = False + + def construct(self, x: ms.Tensor) -> ms.Tensor: + r"""The forward method of the `DecoderTiny` class.""" + # Clamp. + x = ops.tanh(x / 3) * 3 + x = self.layers(x) + + # scale image from [0, 1] to [-1, 1] to match diffusers convention + return x.mul(2).sub(1) + + +class DecoderTinyUpsample(nn.Cell): + def __init__(self, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None): + """ + This class provides an equivalent implementation of the `nn.Upsample`. The native `nn.Upsample` relies on + `ops.interpolate`, which encounters limitations when handling the `scale_factor` parameter in forward. + Instead this class uses the `size` argument to control output tensor shape. + """ + super().__init__() + self.size = size + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + self.recompute_scale_factor = recompute_scale_factor + + def construct(self, x): + if self.size is None: + assert self.scale_factor is not None + size = (int(x.shape[-2] * self.scale_factor), int(x.shape[-1] * self.scale_factor)) + else: + size = self.size + + out = ops.interpolate(x, size, None, self.mode, self.align_corners, self.recompute_scale_factor) + return out diff --git a/mindone/diffusers/models/controlnet.py b/mindone/diffusers/models/controlnet.py index eb09275d3b..b28d137ad9 100644 --- a/mindone/diffusers/models/controlnet.py +++ b/mindone/diffusers/models/controlnet.py @@ -18,6 +18,7 @@ from mindspore import nn, ops from ..configuration_utils import ConfigMixin, register_to_config +from ..loaders import FromOriginalControlNetMixin from ..utils import BaseOutput, logging from .attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps @@ -112,7 +113,7 @@ def construct(self, conditioning): return embedding -class ControlNetModel(ModelMixin, ConfigMixin): +class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin): """ A ControlNet model. 
@@ -549,12 +550,12 @@ def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() - for sub_name, child in module.name_cells(): + for sub_name, child in module.name_cells().items(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors - for name, module in self.name_cells(): + for name, module in self.name_cells().items(): fn_recursive_add_processors(name, module, processors) return processors @@ -585,10 +586,10 @@ def fn_recursive_attn_processor(name: str, module: nn.Cell, processor): else: module.set_processor(processor.pop(f"{name}.processor")) - for sub_name, child in module.name_cells(): + for sub_name, child in module.name_cells().items(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - for name, module in self.name_cells(): + for name, module in self.name_cells().items(): fn_recursive_attn_processor(name, module, processor) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor @@ -733,7 +734,7 @@ def construct( f"which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") - time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = self.add_time_proj(time_ids.flatten()).to(time_ids.dtype) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = ops.concat([text_embeds, time_embeds], axis=-1) diff --git a/mindone/diffusers/models/embeddings.py b/mindone/diffusers/models/embeddings.py index 655a3d331f..4f648b74c4 100644 --- a/mindone/diffusers/models/embeddings.py +++ b/mindone/diffusers/models/embeddings.py @@ -621,19 +621,19 @@ def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): self.image_norm = LayerNorm(time_embed_dim) self.input_hint_block = nn.SequentialCell( nn.Conv2d(3, 16, 3, pad_mode="pad", padding=1, has_bias=True), - nn.SiLU(), + SiLU(), nn.Conv2d(16, 16, 3, pad_mode="pad", padding=1, has_bias=True), - nn.SiLU(), + SiLU(), nn.Conv2d(16, 32, 3, pad_mode="pad", padding=1, stride=2, has_bias=True), - nn.SiLU(), + SiLU(), nn.Conv2d(32, 32, 3, pad_mode="pad", padding=1, has_bias=True), - nn.SiLU(), + SiLU(), nn.Conv2d(32, 96, 3, pad_mode="pad", padding=1, stride=2, has_bias=True), - nn.SiLU(), + SiLU(), nn.Conv2d(96, 96, 3, pad_mode="pad", padding=1, has_bias=True), - nn.SiLU(), + SiLU(), nn.Conv2d(96, 256, 3, pad_mode="pad", padding=1, stride=2, has_bias=True), - nn.SiLU(), + SiLU(), nn.Conv2d(256, 4, 3, pad_mode="pad", padding=1, has_bias=True), ) @@ -709,7 +709,7 @@ def get_fourier_embeds_from_boundingbox(embed_dim, box): batch_size, num_boxes = box.shape[:2] - emb = 100 ** (ops.arange(embed_dim) / embed_dim) + emb = 100 ** (ops.arange(embed_dim).to(dtype=box.dtype) / embed_dim) emb = emb[None, None, None].to(dtype=box.dtype) emb = emb * box.unsqueeze(-1) @@ -734,9 +734,9 @@ def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freq if feature_type == "text-only": self.linears = nn.SequentialCell( nn.Dense(self.positive_len + self.position_dim, 512), - nn.SiLU(), + SiLU(), nn.Dense(512, 512), - nn.SiLU(), + SiLU(), nn.Dense(512, out_dim), ) self.null_positive_feature = ms.Parameter(ops.zeros([self.positive_len]), name="null_positive_feature") @@ -744,16 +744,16 @@ def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freq elif feature_type == "text-image": self.linears_text = nn.SequentialCell( 
nn.Dense(self.positive_len + self.position_dim, 512), - nn.SiLU(), + SiLU(), nn.Dense(512, 512), - nn.SiLU(), + SiLU(), nn.Dense(512, out_dim), ) self.linears_image = nn.SequentialCell( nn.Dense(self.positive_len + self.position_dim, 512), - nn.SiLU(), + SiLU(), nn.Dense(512, 512), - nn.SiLU(), + SiLU(), nn.Dense(512, out_dim), ) self.null_text_feature = ms.Parameter(ops.zeros([self.positive_len]), name="null_text_feature") diff --git a/mindone/diffusers/models/resnet.py b/mindone/diffusers/models/resnet.py index 31001c0109..f3a6efd0ce 100644 --- a/mindone/diffusers/models/resnet.py +++ b/mindone/diffusers/models/resnet.py @@ -18,7 +18,7 @@ import mindspore as ms from mindspore import nn, ops -from .activations import get_activation +from .activations import SiLU, get_activation from .attention_processor import SpatialNorm from .downsampling import Downsample1D, Downsample2D, FirDownsample2D, KDownsample2D, downsample_2d # noqa from .normalization import AdaGroupNorm, GroupNorm @@ -487,24 +487,24 @@ def __init__( # conv layers self.conv1 = nn.SequentialCell( GroupNorm(norm_num_groups, in_dim), - nn.SiLU(), + SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 1, 0, 0, 0, 0), pad_mode="pad", has_bias=True), ) self.conv2 = nn.SequentialCell( GroupNorm(norm_num_groups, out_dim), - nn.SiLU(), + SiLU(), nn.Dropout(p=dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 1, 0, 0, 0, 0), pad_mode="pad", has_bias=True), ) self.conv3 = nn.SequentialCell( GroupNorm(norm_num_groups, out_dim), - nn.SiLU(), + SiLU(), nn.Dropout(p=dropout), nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 1, 0, 0, 0, 0), pad_mode="pad", has_bias=True), ) self.conv4 = nn.SequentialCell( GroupNorm(norm_num_groups, out_dim), - nn.SiLU(), + SiLU(), nn.Dropout(p=dropout), nn.Conv3d( out_dim, diff --git a/mindone/diffusers/models/transformers/__init__.py b/mindone/diffusers/models/transformers/__init__.py index 1c0a2d4662..ec011ad544 100644 --- a/mindone/diffusers/models/transformers/__init__.py +++ b/mindone/diffusers/models/transformers/__init__.py @@ -1,2 +1,6 @@ +from .dual_transformer_2d import DualTransformer2DModel +from .prior_transformer import PriorTransformer +from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel from .transformer_sd3 import SD3Transformer2DModel +from .transformer_temporal import TransformerTemporalModel diff --git a/mindone/diffusers/models/transformers/dual_transformer_2d.py b/mindone/diffusers/models/transformers/dual_transformer_2d.py new file mode 100644 index 0000000000..baa00e3013 --- /dev/null +++ b/mindone/diffusers/models/transformers/dual_transformer_2d.py @@ -0,0 +1,155 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
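Reviewer note: with the expanded `transformers/__init__.py` above, the newly ported transformer modules can be imported directly from the subpackage. A minimal import sketch follows (not part of `dual_transformer_2d.py`; the `mindone.diffusers` prefix is simply taken from the paths in this patch):

from mindone.diffusers.models.transformers import (
    DualTransformer2DModel,
    PriorTransformer,
    SD3Transformer2DModel,
    T5FilmDecoder,
    Transformer2DModel,
    TransformerTemporalModel,
)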
+from typing import Optional + +from mindspore import nn + +from .transformer_2d import Transformer2DModel, Transformer2DModelOutput + + +class DualTransformer2DModel(nn.Cell): + """ + Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + ): + super().__init__() + self.transformers = nn.CellList( + [ + Transformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + ) + for _ in range(2) + ] + ) + + # Variables that can be set by a pipeline: + + # The ratio of transformer1 to transformer2's output states to be combined during inference + self.mix_ratio = 0.5 + + # The shape of `encoder_hidden_states` is expected to be + # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` + self.condition_lengths = [77, 257] + + # Which transformer to use to encode which condition. + # E.g. 
`(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` + self.transformer_index_for_condition = [1, 0] + + def construct( + self, + hidden_states, + encoder_hidden_states, + timestep=None, + attention_mask=None, + cross_attention_kwargs=None, + return_dict: bool = False, + ): + """ + Args: + hidden_states ( When discrete, `ms.Tensor` of shape `(batch size, num latent pixels)`. + When continuous, `ms.Tensor` of shape `(batch size, channel, height, width)`): Input + hidden_states. + encoder_hidden_states ( `ms.Tensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `ms.Tensor`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + attention_mask (`ms.Tensor`, *optional*): + Optional attention mask to be applied in Attention. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + input_states = hidden_states + + encoded_states = [] + tokens_start = 0 + # attention_mask is not used yet + for i in range(2): + # for each of the two transformers, pass the corresponding condition tokens + condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] + transformer_index = self.transformer_index_for_condition[i] + encoded_state = self.transformers[transformer_index]( + input_states, + encoder_hidden_states=condition_state, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + encoded_states.append(encoded_state - input_states) + tokens_start += self.condition_lengths[i] + + output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) + output_states = output_states + input_states + + if not return_dict: + return (output_states,) + + return Transformer2DModelOutput(sample=output_states) diff --git a/mindone/diffusers/models/transformers/prior_transformer.py b/mindone/diffusers/models/transformers/prior_transformer.py new file mode 100644 index 0000000000..1de28c0482 --- /dev/null +++ b/mindone/diffusers/models/transformers/prior_transformer.py @@ -0,0 +1,368 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Union + +import mindspore as ms +from mindspore import nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from ...utils import BaseOutput +from ..attention import BasicTransformerBlock +from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..normalization import LayerNorm + + +@dataclass +class 
PriorTransformerOutput(BaseOutput): + """ + The output of [`PriorTransformer`]. + + Args: + predicted_image_embedding (`ms.Tensor` of shape `(batch_size, embedding_dim)`): + The predicted CLIP image embedding conditioned on the CLIP text embedding input. + """ + + predicted_image_embedding: ms.Tensor + + +class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + """ + A Prior Transformer model. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. + embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states` + num_embeddings (`int`, *optional*, defaults to 77): + The number of embeddings of the model input `hidden_states` + additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the + projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings + + additional_embeddings`. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + time_embed_act_fn (`str`, *optional*, defaults to 'silu'): + The activation function to use to create timestep embeddings. + norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before + passing to Transformer blocks. Set it to `None` if normalization is not needed. + embedding_proj_norm_type (`str`, *optional*, defaults to None): + The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not + needed. + encoder_hid_proj_type (`str`, *optional*, defaults to `linear`): + The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if + `encoder_hidden_states` is `None`. + added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model. + Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot + product between the text embedding and image embedding as proposed in the unclip paper + https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended. + time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings. + If None, will be set to `num_attention_heads * attention_head_dim` + embedding_proj_dim (`int`, *optional*, default to None): + The dimension of `proj_embedding`. If None, will be set to `embedding_dim`. + clip_embed_dim (`int`, *optional*, default to None): + The dimension of the output. If None, will be set to `embedding_dim`. 
+ """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 32, + attention_head_dim: int = 64, + num_layers: int = 20, + embedding_dim: int = 768, + num_embeddings=77, + additional_embeddings=4, + dropout: float = 0.0, + time_embed_act_fn: str = "silu", + norm_in_type: Optional[str] = None, # layer + embedding_proj_norm_type: Optional[str] = None, # layer + encoder_hid_proj_type: Optional[str] = "linear", # linear + added_emb_type: Optional[str] = "prd", # prd + time_embed_dim: Optional[int] = None, + embedding_proj_dim: Optional[int] = None, + clip_embed_dim: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.additional_embeddings = additional_embeddings + + time_embed_dim = time_embed_dim or inner_dim + embedding_proj_dim = embedding_proj_dim or embedding_dim + clip_embed_dim = clip_embed_dim or embedding_dim + + self.time_proj = Timesteps(inner_dim, True, 0) + self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn) + + self.proj_in = nn.Dense(embedding_dim, inner_dim) + + if embedding_proj_norm_type is None: + self.embedding_proj_norm = None + elif embedding_proj_norm_type == "layer": + self.embedding_proj_norm = LayerNorm(embedding_proj_dim) + else: + raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}") + + self.embedding_proj = nn.Dense(embedding_proj_dim, inner_dim) + + if encoder_hid_proj_type is None: + self.encoder_hidden_states_proj = None + elif encoder_hid_proj_type == "linear": + self.encoder_hidden_states_proj = nn.Dense(embedding_dim, inner_dim) + else: + raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}") + + self.positional_embedding = ms.Parameter( + ops.zeros((1, num_embeddings + additional_embeddings, inner_dim)), name="positional_embedding" + ) + + if added_emb_type == "prd": + self.prd_embedding = ms.Parameter(ops.zeros((1, 1, inner_dim)), name="prd_embedding") + elif added_emb_type is None: + self.prd_embedding = None + else: + raise ValueError( + f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`." + ) + + self.transformer_blocks = nn.CellList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + activation_fn="gelu", + attention_bias=True, + ) + for d in range(num_layers) + ] + ) + + if norm_in_type == "layer": + self.norm_in = LayerNorm(inner_dim) + elif norm_in_type is None: + self.norm_in = None + else: + raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.") + + self.norm_out = LayerNorm(inner_dim) + + self.proj_to_clip_embeddings = nn.Dense(inner_dim, clip_embed_dim) + + causal_attention_mask = ops.full( + [num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0 + ) + causal_attention_mask = causal_attention_mask.triu(1) + causal_attention_mask = causal_attention_mask[None, ...] 
+ self.causal_attention_mask = causal_attention_mask + + self.clip_mean = ms.Parameter(ops.zeros((1, clip_embed_dim)), name="clip_mean") + self.clip_std = ms.Parameter(ops.zeros((1, clip_embed_dim)), name="clip_std") + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: # type: ignore + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str, AttentionProcessor]): # type: ignore + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.name_cells().items(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.name_cells().items(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): # type: ignore + r""" + Sets the attention processor to use to compute attention. + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: nn.Cell, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.name_cells().items(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.name_cells().items(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def construct( + self, + hidden_states, + timestep: Union[ms.Tensor, float, int], + proj_embedding: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + return_dict: bool = False, + ): + """ + The [`PriorTransformer`] forward method. 
+ + Args: + hidden_states (`ms.Tensor` of shape `(batch_size, embedding_dim)`): + The currently predicted image embeddings. + timestep (`ms.Tensor`): + Current denoising step. + proj_embedding (`ms.Tensor` of shape `(batch_size, embedding_dim)`): + Projected embedding vector the denoising process is conditioned on. + encoder_hidden_states (`ms.Tensor` of shape `(batch_size, num_embeddings, embedding_dim)`): + Hidden states of the text embeddings the denoising process is conditioned on. + attention_mask (`ms.Tensor` of shape `(batch_size, num_embeddings)`): + Text mask for the text embeddings. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain + tuple. + + Returns: + [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`: + If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + batch_size = hidden_states.shape[0] + + timesteps = timestep + if not ops.is_tensor(timesteps): + timesteps = ms.Tensor([timesteps], dtype=ms.int64) + elif ops.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None] + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps * ops.ones(batch_size, dtype=timesteps.dtype) + + timesteps_projected = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + timesteps_projected = timesteps_projected.to(dtype=self.dtype) + time_embeddings = self.time_embedding(timesteps_projected) + + if self.embedding_proj_norm is not None: + proj_embedding = self.embedding_proj_norm(proj_embedding) + + proj_embeddings = self.embedding_proj(proj_embedding) + if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None: + encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states) + elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None: + raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set") + + hidden_states = self.proj_in(hidden_states) + + positional_embeddings = self.positional_embedding.to(hidden_states.dtype) + + additional_embeds = [] + additional_embeddings_len = 0 + + if encoder_hidden_states is not None: + additional_embeds.append(encoder_hidden_states) + additional_embeddings_len += encoder_hidden_states.shape[1] + + if len(proj_embeddings.shape) == 2: + proj_embeddings = proj_embeddings[:, None, :] + + if len(hidden_states.shape) == 2: + hidden_states = hidden_states[:, None, :] + + additional_embeds = additional_embeds + [ + proj_embeddings, + time_embeddings[:, None, :], + hidden_states, + ] + + if self.prd_embedding is not None: + prd_embedding = self.prd_embedding.to(hidden_states.dtype).broadcast_to((batch_size, -1, -1)) + additional_embeds.append(prd_embedding) + + hidden_states = ops.cat( + additional_embeds, + axis=1, + ) + + # Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens + additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1 + if positional_embeddings.shape[1] < hidden_states.shape[1]: + positional_embeddings = ops.Pad( + ( + (0, 0), + (additional_embeddings_len, self.prd_embedding.shape[1] if self.prd_embedding is not None else 0), + 
(0, 0), + ) + )(positional_embeddings) + + hidden_states = hidden_states + positional_embeddings + + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = ops.Pad(((0, 0), (0, self.additional_embeddings)))(attention_mask) + attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype) + attention_mask = attention_mask.repeat_interleave(self.config["num_attention_heads"], dim=0) + + if self.norm_in is not None: + hidden_states = self.norm_in(hidden_states) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states, attention_mask=attention_mask) + + hidden_states = self.norm_out(hidden_states) + + if self.prd_embedding is not None: + hidden_states = hidden_states[:, -1] + else: + hidden_states = hidden_states[:, additional_embeddings_len:] + + predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states) + + if not return_dict: + return (predicted_image_embedding,) + + return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding) + + def post_process_latents(self, prior_latents): + prior_latents = (prior_latents * self.clip_std) + self.clip_mean + return prior_latents diff --git a/mindone/diffusers/models/transformers/t5_film_transformer.py b/mindone/diffusers/models/transformers/t5_film_transformer.py new file mode 100644 index 0000000000..36cea32110 --- /dev/null +++ b/mindone/diffusers/models/transformers/t5_film_transformer.py @@ -0,0 +1,437 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Optional, Tuple + +import mindspore as ms +from mindspore import nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ..activations import SiLU +from ..attention_processor import Attention +from ..embeddings import get_timestep_embedding +from ..modeling_utils import ModelMixin + + +class T5FilmDecoder(ModelMixin, ConfigMixin): + r""" + T5 style decoder with FiLM conditioning. + + Args: + input_dims (`int`, *optional*, defaults to `128`): + The number of input dimensions. + targets_length (`int`, *optional*, defaults to `256`): + The length of the targets. + d_model (`int`, *optional*, defaults to `768`): + Size of the input hidden states. + num_layers (`int`, *optional*, defaults to `12`): + The number of `DecoderLayer`'s to use. + num_heads (`int`, *optional*, defaults to `12`): + The number of attention heads to use. + d_kv (`int`, *optional*, defaults to `64`): + Size of the key-value projection vectors. + d_ff (`int`, *optional*, defaults to `2048`): + The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s. + dropout_rate (`float`, *optional*, defaults to `0.1`): + Dropout probability. 
+ """ + + @register_to_config + def __init__( + self, + input_dims: int = 128, + targets_length: int = 256, + max_decoder_noise_time: float = 2000.0, + d_model: int = 768, + num_layers: int = 12, + num_heads: int = 12, + d_kv: int = 64, + d_ff: int = 2048, + dropout_rate: float = 0.1, + ): + super().__init__() + + self.conditioning_emb = nn.SequentialCell( + nn.Dense(d_model, d_model * 4, has_bias=False), + SiLU(), + nn.Dense(d_model * 4, d_model * 4, has_bias=False), + SiLU(), + ) + + self.position_encoding = nn.Embedding(targets_length, d_model) + self.position_encoding.embedding_table.requires_grad = False + + self.continuous_inputs_projection = nn.Dense(input_dims, d_model, has_bias=False) + + self.dropout = nn.Dropout(p=dropout_rate) + + self.decoders = [] + for lyr_num in range(num_layers): + # FiLM conditional T5 decoder + lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate) + self.decoders.append(lyr) + self.decoders = nn.CellList(self.decoders) + + self.decoder_norm = T5LayerNorm(d_model) + + self.post_dropout = nn.Dropout(p=dropout_rate) + self.spec_out = nn.Dense(d_model, input_dims, has_bias=False) + + def encoder_decoder_mask(self, query_input: ms.Tensor, key_input: ms.Tensor) -> ms.Tensor: + mask = ops.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2)) + return mask.unsqueeze(-3) + + def construct(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time): + batch, _, _ = decoder_input_tokens.shape + assert decoder_noise_time.shape == (batch,) + + # decoder_noise_time is in [0, 1), so rescale to expected timing range. + time_steps = get_timestep_embedding( + decoder_noise_time * self.config["max_decoder_noise_time"], + embedding_dim=self.config["d_model"], + max_period=self.config["max_decoder_noise_time"], + ).to(dtype=self.dtype) + + conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1) + + assert conditioning_emb.shape == (batch, 1, self.config["d_model"] * 4) + + seq_length = decoder_input_tokens.shape[1] + + # If we want to use relative positions for audio context, we can just offset + # this sequence by the length of encodings_and_masks. + decoder_positions = ops.broadcast_to(ops.arange(seq_length), (batch, seq_length)) + + position_encodings = self.position_encoding(decoder_positions) + + inputs = self.continuous_inputs_projection(decoder_input_tokens) + inputs += position_encodings + y = self.dropout(inputs) + + # decoder: No padding present. + decoder_mask = ops.ones(decoder_input_tokens.shape[:2], dtype=inputs.dtype) + + # Translate encoding masks to encoder-decoder masks. + encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks] + + # cross attend style: concat encodings + encoded = ops.cat([x[0] for x in encodings_and_encdec_masks], axis=1) + encoder_decoder_mask = ops.cat([x[1] for x in encodings_and_encdec_masks], axis=-1) + + for lyr in self.decoders: + y = lyr( + y, + conditioning_emb=conditioning_emb, + encoder_hidden_states=encoded, + encoder_attention_mask=encoder_decoder_mask, + )[0] + + y = self.decoder_norm(y) + y = self.post_dropout(y) + + spec_out = self.spec_out(y) + return spec_out + + +class DecoderLayer(nn.Cell): + r""" + T5 decoder layer. + + Args: + d_model (`int`): + Size of the input hidden states. + d_kv (`int`): + Size of the key-value projection vectors. + num_heads (`int`): + Number of attention heads. + d_ff (`int`): + Size of the intermediate feed-forward layer. 
+ dropout_rate (`float`): + Dropout probability. + layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`): + A small value used for numerical stability to avoid dividing by zero. + """ + + def __init__( + self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6 + ): + super().__init__() + layers = [] + + # cond self attention: layer 0 + layers.append( + T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate) + ) + + # cross attention: layer 1 + layers.append( + T5LayerCrossAttention( + d_model=d_model, + d_kv=d_kv, + num_heads=num_heads, + dropout_rate=dropout_rate, + layer_norm_epsilon=layer_norm_epsilon, + ) + ) + + # Film Cond MLP + dropout: last layer + layers.append( + T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon) + ) + + self.layer = nn.CellList(layers) + + def construct( + self, + hidden_states: ms.Tensor, + conditioning_emb: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + encoder_decoder_position_bias=None, + ) -> Tuple[ms.Tensor]: + hidden_states = self.layer[0]( + hidden_states, + conditioning_emb=conditioning_emb, + attention_mask=attention_mask, + ) + + if encoder_hidden_states is not None: + encoder_extended_attention_mask = ops.where( + encoder_attention_mask > 0, ms.Tensor(0.0), ms.Tensor(-1e10) + ).to(encoder_hidden_states.dtype) + + hidden_states = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_extended_attention_mask, + ) + + # Apply Film Conditional Feed Forward layer + hidden_states = self.layer[-1](hidden_states, conditioning_emb) + + return (hidden_states,) + + +class T5LayerSelfAttentionCond(nn.Cell): + r""" + T5 style self-attention layer with conditioning. + + Args: + d_model (`int`): + Size of the input hidden states. + d_kv (`int`): + Size of the key-value projection vectors. + num_heads (`int`): + Number of attention heads. + dropout_rate (`float`): + Dropout probability. + """ + + def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float): + super().__init__() + self.layer_norm = T5LayerNorm(d_model) + self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) + self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) + self.dropout = nn.Dropout(p=dropout_rate) + + def construct( + self, + hidden_states: ms.Tensor, + conditioning_emb: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + # pre_self_attention_layer_norm + normed_hidden_states = self.layer_norm(hidden_states) + + if conditioning_emb is not None: + normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb) + + # Self-attention block + attention_output = self.attention(normed_hidden_states) + + hidden_states = hidden_states + self.dropout(attention_output) + + return hidden_states + + +class T5LayerCrossAttention(nn.Cell): + r""" + T5 style cross-attention layer. + + Args: + d_model (`int`): + Size of the input hidden states. + d_kv (`int`): + Size of the key-value projection vectors. + num_heads (`int`): + Number of attention heads. + dropout_rate (`float`): + Dropout probability. + layer_norm_epsilon (`float`): + A small value used for numerical stability to avoid dividing by zero. 
+ """ + + def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float): + super().__init__() + self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False) + self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) + self.dropout = nn.Dropout(p=dropout_rate) + + def construct( + self, + hidden_states: ms.Tensor, + key_value_states: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.attention( + normed_hidden_states, + encoder_hidden_states=key_value_states, + attention_mask=attention_mask.squeeze(1), + ) + layer_output = hidden_states + self.dropout(attention_output) + return layer_output + + +class T5LayerFFCond(nn.Cell): + r""" + T5 style feed-forward conditional layer. + + Args: + d_model (`int`): + Size of the input hidden states. + d_ff (`int`): + Size of the intermediate feed-forward layer. + dropout_rate (`float`): + Dropout probability. + layer_norm_epsilon (`float`): + A small value used for numerical stability to avoid dividing by zero. + """ + + def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float): + super().__init__() + self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate) + self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model) + self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon) + self.dropout = nn.Dropout(p=dropout_rate) + + def construct(self, hidden_states: ms.Tensor, conditioning_emb: Optional[ms.Tensor] = None) -> ms.Tensor: + forwarded_states = self.layer_norm(hidden_states) + if conditioning_emb is not None: + forwarded_states = self.film(forwarded_states, conditioning_emb) + + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Cell): + r""" + T5 style feed-forward layer with gated activations and dropout. + + Args: + d_model (`int`): + Size of the input hidden states. + d_ff (`int`): + Size of the intermediate feed-forward layer. + dropout_rate (`float`): + Dropout probability. + """ + + def __init__(self, d_model: int, d_ff: int, dropout_rate: float): + super().__init__() + self.wi_0 = nn.Dense(d_model, d_ff, has_bias=False) + self.wi_1 = nn.Dense(d_model, d_ff, has_bias=False) + self.wo = nn.Dense(d_ff, d_model, has_bias=False) + self.dropout = nn.Dropout(p=dropout_rate) + self.act = NewGELUActivation() + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerNorm(nn.Cell): + r""" + T5 style layer normalization module. + + Args: + hidden_size (`int`): + Size of the input hidden states. + eps (`float`, `optional`, defaults to `1e-6`): + A small value used for numerical stability to avoid dividing by zero. + """ + + def __init__(self, hidden_size: int, eps: float = 1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. 
+ """ + super().__init__() + self.weight = ms.Parameter(ops.ones(hidden_size), name="weight") + self.variance_epsilon = eps + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(ms.float32).pow(2).mean(-1, keep_dims=True) + hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [ms.float16, ms.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class NewGELUActivation(nn.Cell): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + + def construct(self, input: ms.Tensor) -> ms.Tensor: + return ( + 0.5 * input * (1.0 + ops.tanh(float(math.sqrt(2.0 / math.pi)) * (input + 0.044715 * ops.pow(input, 3.0)))) + ) + + +class T5FiLMLayer(nn.Cell): + """ + T5 style FiLM Layer. + + Args: + in_features (`int`): + Number of input features. + out_features (`int`): + Number of output features. + """ + + def __init__(self, in_features: int, out_features: int): + super().__init__() + self.scale_bias = nn.Dense(in_features, out_features * 2, has_bias=False) + + def construct(self, x: ms.Tensor, conditioning_emb: ms.Tensor) -> ms.Tensor: + emb = self.scale_bias(conditioning_emb) + scale, shift = ops.chunk(emb, 2, -1) + x = x * (1 + scale) + shift + return x diff --git a/mindone/diffusers/models/transformers/transformer_temporal.py b/mindone/diffusers/models/transformers/transformer_temporal.py new file mode 100644 index 0000000000..74a51e87ff --- /dev/null +++ b/mindone/diffusers/models/transformers/transformer_temporal.py @@ -0,0 +1,368 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import mindspore as ms +from mindspore import nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..normalization import GroupNorm +from ..resnet import AlphaBlender + + +@dataclass +class TransformerTemporalModelOutput(BaseOutput): + """ + The output of [`TransformerTemporalModel`]. + + Args: + sample (`ms.Tensor` of shape `(batch_size x num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. 
+ """ + + sample: ms.Tensor + + +class TransformerTemporalModel(ModelMixin, ConfigMixin): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlock` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported + activation functions. + norm_elementwise_affine (`bool`, *optional*): + Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. + double_self_attention (`bool`, *optional*): + Configure if each `TransformerBlock` should contain two self-attention layers. + positional_embeddings: (`str`, *optional*): + The type of positional embeddings to apply to the sequence input before passing use. + num_positional_embeddings: (`int`, *optional*): + The maximum length of the sequence over which to apply positional embeddings. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Dense(in_channels, inner_dim) + + # 3. 
Define transformers blocks + self.transformer_blocks = nn.CellList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + positional_embeddings=positional_embeddings, + num_positional_embeddings=num_positional_embeddings, + ) + for d in range(num_layers) + ] + ) + + self.proj_out = nn.Dense(inner_dim, in_channels) + + def construct( + self, + hidden_states: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, + timestep: Optional[ms.Tensor] = None, + class_labels: ms.Tensor = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = False, + ) -> TransformerTemporalModelOutput: + """ + The [`TransformerTemporal`] forward method. + + Args: + hidden_states (`ms.Tensor` of shape `(batch size, num latent pixels)` if discrete, + `ms.Tensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. + encoder_hidden_states ( `ms.Tensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `ms.Tensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `ms.Tensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + num_frames (`int`, *optional*, defaults to 1): + The number of frames to be processed per batch. This is used to reshape the hidden states. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is + returned, otherwise a `tuple` where the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. 
Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :].reshape(batch_size, height, width, num_frames, channel).permute(0, 3, 4, 1, 2) + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) + + +class TransformerSpatioTemporalModel(nn.Cell): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + out_channels (`int`, *optional*): + The number of channels in the output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: int = 320, + out_channels: Optional[int] = None, + num_layers: int = 1, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + + # 2. Define input layers + self.in_channels = in_channels + self.norm = GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6) + self.proj_in = nn.Dense(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.CellList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for d in range(num_layers) + ] + ) + + time_mix_inner_dim = inner_dim + self.temporal_transformer_blocks = nn.CellList( + [ + TemporalBasicTransformerBlock( + inner_dim, + time_mix_inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for _ in range(num_layers) + ] + ) + + time_embed_dim = in_channels * 4 + self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels) + self.time_proj = Timesteps(in_channels, True, 0) + self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images") + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + # TODO: should use out_channels for continuous projections + self.proj_out = nn.Dense(inner_dim, in_channels) + + self.gradient_checkpointing = False + + def construct( + self, + hidden_states: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, + image_only_indicator: Optional[ms.Tensor] = None, + return_dict: bool = False, + ): + """ + Args: + hidden_states (`ms.Tensor` of shape `(batch size, channel, height, width)`): + Input hidden_states. + num_frames (`int`): + The number of frames to be processed per batch. This is used to reshape the hidden states. + encoder_hidden_states ( `ms.Tensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. 
+ image_only_indicator (`ms.Tensor` of shape `(batch size, num_frames)`, *optional*): + A tensor indicating whether the input contains only images. 1 indicates that the input contains only + images, 0 indicates that the input contains video frames. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain + tuple. + + Returns: + [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is + returned, otherwise a `tuple` where the first element is the sample tensor. + """ + # 1. Input + batch_frames, _, height, width = hidden_states.shape + num_frames = image_only_indicator.shape[-1] + batch_size = batch_frames // num_frames + + time_context = encoder_hidden_states + time_context_first_timestep = time_context[None, :].reshape(batch_size, num_frames, -1, time_context.shape[-1])[ + :, 0 + ] + time_context = time_context_first_timestep[None, :].broadcast_to( + (height * width, batch_size, 1, time_context.shape[-1]) + ) + time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1]) + + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + num_frames_emb = ops.arange(num_frames) + num_frames_emb = num_frames_emb.tile((batch_size, 1)) + num_frames_emb = num_frames_emb.reshape(-1) + t_emb = self.time_proj(num_frames_emb) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + # 2. Blocks + for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks): + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states_mix = hidden_states + hidden_states_mix = hidden_states_mix + emb + + hidden_states_mix = temporal_block( + hidden_states_mix, + num_frames=num_frames, + encoder_hidden_states=time_context, + ) + hidden_states = self.time_mixer( + x_spatial=hidden_states, + x_temporal=hidden_states_mix, + image_only_indicator=image_only_indicator, + ) + + # 3. 
Output + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2) + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) diff --git a/mindone/diffusers/models/unets/__init__.py b/mindone/diffusers/models/unets/__init__.py index 5e2a1bd011..f5493f7b0e 100644 --- a/mindone/diffusers/models/unets/__init__.py +++ b/mindone/diffusers/models/unets/__init__.py @@ -1,4 +1,10 @@ from .unet_1d import UNet1DModel from .unet_2d import UNet2DModel from .unet_2d_condition import UNet2DConditionModel +from .unet_3d_condition import UNet3DConditionModel +from .unet_i2vgen_xl import I2VGenXLUNet +from .unet_kandinsky3 import Kandinsky3UNet +from .unet_motion_model import MotionAdapter, UNetMotionModel +from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel from .unet_stable_cascade import StableCascadeUNet +from .uvit_2d import UVit2DModel diff --git a/mindone/diffusers/models/unets/unet_2d.py b/mindone/diffusers/models/unets/unet_2d.py index e4a93b26aa..0278480847 100644 --- a/mindone/diffusers/models/unets/unet_2d.py +++ b/mindone/diffusers/models/unets/unet_2d.py @@ -20,7 +20,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput from ..activations import get_activation -from ..embeddings import TimestepEmbedding, Timesteps +from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from ..modeling_utils import ModelMixin from ..normalization import GroupNorm from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @@ -144,7 +144,8 @@ def __init__( # time if time_embedding_type == "fourier": - raise NotImplementedError("GaussianFourierProjection is not implemented") + self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16) + timestep_input_dim = 2 * block_out_channels[0] elif time_embedding_type == "positional": self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] @@ -352,6 +353,10 @@ def construct( if skip_sample is not None: sample += skip_sample + if self.config["time_embedding_type"] == "fourier": + timesteps = timesteps.reshape((sample.shape[0],) + (1,) * len(sample.shape[1:])) + sample = sample / timesteps + if not return_dict: return (sample,) diff --git a/mindone/diffusers/models/unets/unet_2d_blocks.py b/mindone/diffusers/models/unets/unet_2d_blocks.py index 3e430acce4..c2b909465d 100644 --- a/mindone/diffusers/models/unets/unet_2d_blocks.py +++ b/mindone/diffusers/models/unets/unet_2d_blocks.py @@ -13,13 +13,26 @@ # limitations under the License. 
from typing import Any, Dict, Optional, Tuple, Union +import numpy as np + import mindspore as ms from mindspore import nn, ops from ...utils import logging -from ..activations import get_activation -from ..attention_processor import Attention -from ..resnet import Downsample2D, ResnetBlock2D, Upsample2D +from ..activations import SiLU, get_activation +from ..attention_processor import Attention, AttnAddedKVProcessor +from ..normalization import AdaGroupNorm, GroupNorm +from ..resnet import ( + Downsample2D, + FirDownsample2D, + FirUpsample2D, + KDownsample2D, + KUpsample2D, + ResnetBlock2D, + ResnetBlockCondNorm2D, + Upsample2D, +) +from ..transformers.dual_transformer_2d import DualTransformer2DModel from ..transformers.transformer_2d import Transformer2DModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -75,7 +88,20 @@ def get_down_block( resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "ResnetDownsampleBlock2D": - raise NotImplementedError + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) elif down_block_type == "AttnDownBlock2D": if add_downsample is False: downsample_type = None @@ -120,11 +146,52 @@ def get_down_block( attention_type=attention_type, ) elif down_block_type == "SimpleCrossAttnDownBlock2D": - raise NotImplementedError + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) elif down_block_type == "SkipDownBlock2D": - raise NotImplementedError + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) elif down_block_type == "AttnSkipDownBlock2D": - raise NotImplementedError + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) elif down_block_type == "DownEncoderBlock2D": return DownEncoderBlock2D( num_layers=num_layers, @@ -139,11 +206,44 @@ def get_down_block( resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "AttnDownEncoderBlock2D": - raise NotImplementedError + return AttnDownEncoderBlock2D( + 
num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) elif down_block_type == "KDownBlock2D": - raise NotImplementedError + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) elif down_block_type == "KCrossAttnDownBlock2D": - raise NotImplementedError + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + add_self_attention=True if not add_downsample else False, + ) raise ValueError(f"{down_block_type} does not exist.") @@ -188,7 +288,21 @@ def get_mid_block( attention_type=attention_type, ) elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": - raise NotImplementedError + return UNetMidBlock2DSimpleCrossAttn( + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) elif mid_block_type == "UNetMidBlock2D": return UNetMidBlock2D( in_channels=in_channels, @@ -260,7 +374,22 @@ def get_up_block( resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "ResnetUpsampleBlock2D": - raise NotImplementedError + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) elif up_block_type == "CrossAttnUpBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") @@ -287,7 +416,28 @@ def get_up_block( attention_type=attention_type, ) elif up_block_type == "SimpleCrossAttnUpBlock2D": - raise NotImplementedError + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, 
+ skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) elif up_block_type == "AttnUpBlock2D": if add_upsample is False: upsample_type = None @@ -310,9 +460,34 @@ def get_up_block( upsample_type=upsample_type, ) elif up_block_type == "SkipUpBlock2D": - raise NotImplementedError + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) elif up_block_type == "AttnSkipUpBlock2D": - raise NotImplementedError + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) elif up_block_type == "UpDecoderBlock2D": return UpDecoderBlock2D( num_layers=num_layers, @@ -328,11 +503,46 @@ def get_up_block( temb_channels=temb_channels, ) elif up_block_type == "AttnUpDecoderBlock2D": - raise NotImplementedError + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) elif up_block_type == "KUpBlock2D": - raise NotImplementedError + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) elif up_block_type == "KCrossAttnUpBlock2D": - raise NotImplementedError + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + ) raise ValueError(f"{up_block_type} does not exist.") @@ -357,11 +567,11 @@ def __init__(self, in_channels: int, out_channels: int, act_fn: str): super().__init__() act_fn = get_activation(act_fn)() self.conv = nn.SequentialCell( - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, has_bias=True), + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, pad_mode="pad", has_bias=True), act_fn, - nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, has_bias=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, pad_mode="pad", has_bias=True), act_fn, - nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, has_bias=True), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, pad_mode="pad", has_bias=True), ) self.skip = ( nn.Conv2d(in_channels, out_channels, kernel_size=1, has_bias=False) @@ 
-425,13 +635,26 @@ def __init__( resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) self.add_attention = add_attention self.has_cross_attention = False + self.num_layers = num_layers if attn_groups is None: attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None # there is always at least one resnet if resnet_time_scale_shift == "spatial": - raise NotImplementedError("ResnetBlockCondNorm2D is not implemented") + resnets = [ + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] else: resnets = [ ResnetBlock2D( @@ -473,10 +696,23 @@ def __init__( ) ) else: - attentions.append(None) + # nn.CellList doesn't support append 'None', thus we have to modify construct() code to fit it + pass if resnet_time_scale_shift == "spatial": - raise NotImplementedError("ResnetBlockCondNorm2D is not implemented") + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) else: resnets.append( ResnetBlock2D( @@ -498,9 +734,13 @@ def __init__( def construct(self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None) -> ms.Tensor: hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: + # nn.CellList doesn't support append 'None', thus we have to modify code to fit it + for i in range(self.num_layers): + if self.add_attention: + attn = self.attentions[i] hidden_states = attn(hidden_states, temb=temb) + + resnet = self.resnets[i + 1] hidden_states = resnet(hidden_states, temb) return hidden_states @@ -530,6 +770,7 @@ def __init__( super().__init__() self.has_cross_attention = True + self.has_motion_modules = False self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) @@ -570,7 +811,16 @@ def __init__( ) ) else: - raise NotImplementedError("DualTransformer2DModel is not implemented") + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) resnets.append( ResnetBlock2D( in_channels=in_channels, @@ -626,11 +876,10 @@ def construct( return hidden_states -class AttnDownBlock2D(nn.Cell): +class UNetMidBlock2DSimpleCrossAttn(nn.Cell): def __init__( self, in_channels: int, - out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, @@ -641,28 +890,150 @@ def __init__( resnet_pre_norm: bool = True, attention_head_dim: int = 1, output_scale_factor: float = 1.0, - downsample_padding: int = 1, - downsample_type: str = "conv", + cross_attention_dim: int = 1280, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, ): super().__init__() - resnets = [] - attentions = [] - self.downsample_type = downsample_type - self.has_cross_attention = False - if attention_head_dim is None: - logger.warning( - f"It is not recommend to pass `attention_head_dim=None`. 
" - f"Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + self.has_cross_attention = True + + self.attention_head_dim = attention_head_dim + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.num_heads = in_channels // self.attention_head_dim + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, ) - attention_head_dim = out_channels + ] + attentions = [] - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock2D( - in_channels=in_channels, - out_channels=out_channels, + for _ in range(num_layers): + processor = AttnAddedKVProcessor() + + attentions.append( + Attention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. 
+ # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + # attn + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + # resnet + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class AttnDownBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + downsample_type: str = "conv", + ): + super().__init__() + resnets = [] + attentions = [] + self.downsample_type = downsample_type + self.has_cross_attention = False + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. " + f"Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, @@ -814,7 +1185,16 @@ def __init__( ) ) else: - raise NotImplementedError("DualTransformer2DModel is not implemented") + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) self.attentions = nn.CellList(attentions) self.resnets = nn.CellList(resnets) @@ -988,7 +1368,19 @@ def __init__( for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels if resnet_time_scale_shift == "spatial": - raise NotImplementedError("ResnetBlockCondNorm2D is not implemented") + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) else: resnets.append( ResnetBlock2D( @@ -1029,14 +1421,11 @@ def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: return hidden_states -class AttnUpBlock2D(nn.Cell): +class AttnDownEncoderBlock2D(nn.Cell): def __init__( self, in_channels: int, - prev_output_channel: int, out_channels: int, - temb_channels: int, - resolution_idx: int = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, @@ -1046,40 +1435,52 @@ def __init__( resnet_pre_norm: bool = True, attention_head_dim: int = 1, output_scale_factor: float = 1.0, - upsample_type: str = "conv", + add_downsample: bool = True, + downsample_padding: int = 1, ): super().__init__() resnets = [] attentions = [] self.has_cross_attention = False - self.upsample_type = upsample_type if attention_head_dim is None: logger.warning( - f"It is not recommend to pass `attention_head_dim=None`. " - f"Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + f"It is not recommend to pass `attention_head_dim=None`. 
Defaulting `attention_head_dim` to `in_channels`: {out_channels}." ) attention_head_dim = out_channels for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, + in_channels = in_channels if i == 0 else out_channels + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) ) - ) attentions.append( Attention( out_channels, @@ -1098,104 +1499,66 @@ def __init__( self.attentions = nn.CellList(attentions) self.resnets = nn.CellList(resnets) - if upsample_type == "conv": - self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) - elif upsample_type == "resnet": - self.upsamplers = nn.CellList( + if add_downsample: + self.downsamplers = nn.CellList( [ - ResnetBlock2D( - in_channels=out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - up=True, + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: - self.upsamplers = None - - self.resolution_idx = resolution_idx + self.downsamplers = None - def construct( - self, - hidden_states: ms.Tensor, - res_hidden_states_tuple: Tuple[ms.Tensor, ...], - temb: Optional[ms.Tensor] = None, - upsample_size: Optional[int] = None, - ) -> ms.Tensor: + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: for resnet, attn in zip(self.resnets, self.attentions): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) - - hidden_states = resnet(hidden_states, temb) + hidden_states = resnet(hidden_states, temb=None) hidden_states = attn(hidden_states) - if self.upsamplers is not None: - for upsampler in self.upsamplers: - if self.upsample_type == "resnet": - hidden_states = upsampler(hidden_states, temb=temb) - else: - hidden_states = upsampler(hidden_states) + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) return hidden_states -class CrossAttnUpBlock2D(nn.Cell): +class AttnSkipDownBlock2D(nn.Cell): def __init__( self, in_channels: int, out_channels: int, - prev_output_channel: int, temb_channels: int, - resolution_idx: Optional[int] 
= None, dropout: float = 0.0, num_layers: int = 1, - transformer_layers_per_block: Union[int, Tuple[int]] = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", - resnet_groups: int = 32, resnet_pre_norm: bool = True, - num_attention_heads: int = 1, - cross_attention_dim: int = 1280, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - attention_type: str = "default", + attention_head_dim: int = 1, + output_scale_factor: float = np.sqrt(2.0), + add_downsample: bool = True, ): super().__init__() - resnets = [] - attentions = [] - - self.has_cross_attention = True - self.num_attention_heads = num_attention_heads + self.has_cross_attention = False + self.attentions = [] + self.resnets = [] - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * num_layers + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, + in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, - groups=resnet_groups, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, @@ -1203,43 +1566,1515 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - if not dual_cross_attention: - attentions.append( - Transformer2DModel( - num_attention_heads, - out_channels // num_attention_heads, - in_channels=out_channels, - num_layers=transformer_layers_per_block[i], - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + self.attentions = nn.CellList(self.attentions) + self.resnets = nn.CellList(self.resnets) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.CellList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + skip_sample: 
Optional[ms.Tensor] = None, + ) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...], ms.Tensor]: + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class SkipDownBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor: float = np.sqrt(2.0), + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + self.resnets = [] + self.has_cross_attention = False + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.CellList(self.resnets) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.CellList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + skip_sample: Optional[ms.Tensor] = None, + ) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...], ms.Tensor]: + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class ResnetDownsampleBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + skip_time_act: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else 
out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.resnets = nn.CellList(resnets) + + if add_downsample: + self.downsamplers = nn.CellList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + self.has_cross_attention = False + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + ) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...]]: + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class SimpleCrossAttnDownBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + + self.has_cross_attention = True + + resnets = [] + attentions = [] + + self.attention_head_dim = attention_head_dim + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + processor = AttnAddedKVProcessor() + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + + if add_downsample: + self.downsamplers = nn.CellList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + 
output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + ) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...]]: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + output_states = () + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class KDownBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + add_downsample: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.CellList(resnets) + + if add_downsample: + # YiYi's comments- might be able to use FirDownsample2D, look into details later + self.downsamplers = nn.CellList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + self.has_cross_attention = False + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + ) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...]]: + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class KCrossAttnDownBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: 
int, + temb_channels: int, + cross_attention_dim: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_group_size: int = 32, + add_downsample: bool = True, + attention_head_dim: int = 64, + add_self_attention: bool = False, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + out_channels, + out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + group_size=resnet_group_size, + ) + ) + + self.resnets = nn.CellList(resnets) + self.attentions = nn.CellList(attentions) + + if add_downsample: + self.downsamplers = nn.CellList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + ) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...]]: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.downsamplers is None: + output_states += (None,) + else: + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class AttnUpBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + upsample_type: str = "conv", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = False + self.upsample_type = upsample_type + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. " + f"Defaulting `attention_head_dim` to `in_channels`: {out_channels}." 
+ ) + attention_head_dim = out_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + + if upsample_type == "conv": + self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + elif upsample_type == "resnet": + self.upsamplers = nn.CellList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + upsample_size: Optional[int] = None, + ) -> ms.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + if self.upsample_type == "resnet": + hidden_states = upsampler(hidden_states, temb=temb) + else: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else 
out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, attention_type=attention_type, ) ) - else: - raise NotImplementedError("DualTransformer2DModel is not implemented") + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + + if add_upsample: + self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + self._gradient_checkpointing = False + + @property + def gradient_checkpointing(self): + return self._gradient_checkpointing + + @gradient_checkpointing.setter + def gradient_checkpointing(self, value): + self._gradient_checkpointing = value + for resnet in self.resnets: + resnet._recompute(value) + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[ms.Tensor] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + raise NotImplementedError("apply_freeu is not implemented") + + hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", 
+ resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + self.has_cross_attention = False + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.CellList(resnets) + + if add_upsample: + self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + self._gradient_checkpointing = False + + @property + def gradient_checkpointing(self): + return self._gradient_checkpointing + + @gradient_checkpointing.setter + def gradient_checkpointing(self, value): + self._gradient_checkpointing = value + for resnet in self.resnets: + resnet._recompute(value) + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + upsample_size: Optional[int] = None, + ) -> ms.Tensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + raise NotImplementedError("apply_freeu is not implemented") + + hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) + + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpDecoderBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + ): + super().__init__() + resnets = [] + + self.has_cross_attention = False + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, 
+ ) + ) + + self.resnets = nn.CellList(resnets) + + if add_upsample: + self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def construct(self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None) -> ms.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnUpDecoderBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + ): + super().__init__() + resnets = [] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + self.has_cross_attention = False + + if add_upsample: + self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def construct(self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None) -> ms.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=temb) + hidden_states = attn(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnSkipUpBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + 
resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = np.sqrt(2.0), + add_upsample: bool = True, + ): + super().__init__() + self.attentions = [] + self.resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(resnet_in_channels + res_skip_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d( + out_channels, 3, kernel_size=(3, 3), stride=(1, 1), pad_mode="pad", padding=(1, 1) + ) + self.skip_norm = GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + self.attentions = nn.CellList(self.attentions) + self.resnets = nn.CellList(self.resnets) + self.resolution_idx = resolution_idx + self.has_cross_attention = False + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + skip_sample=None, + ) -> Tuple[ms.Tensor, ms.Tensor]: + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) + + hidden_states = resnet(hidden_states, temb) + + hidden_states = self.attentions[0](hidden_states) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class 
SkipUpBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor: float = np.sqrt(2.0), + add_upsample: bool = True, + upsample_padding: int = 1, + ): + super().__init__() + self.resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min((resnet_in_channels + res_skip_channels) // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.resnets = nn.CellList(self.resnets) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d( + out_channels, 3, kernel_size=(3, 3), stride=(1, 1), pad_mode="pad", padding=(1, 1) + ) + self.skip_norm = GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + self.resolution_idx = resolution_idx + self.has_cross_attention = False + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + skip_sample=None, + ) -> Tuple[ms.Tensor, ms.Tensor]: + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) + + hidden_states = resnet(hidden_states, temb) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class ResnetUpsampleBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: 
bool = True, + skip_time_act: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.resnets = nn.CellList(resnets) + + if add_upsample: + self.upsamplers = nn.CellList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + self.has_cross_attention = False + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + upsample_size: Optional[int] = None, + ) -> ms.Tensor: + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) + + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states + + +class SimpleCrossAttnUpBlock2D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + processor = AttnAddedKVProcessor() + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + 
added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) self.attentions = nn.CellList(attentions) self.resnets = nn.CellList(resnets) if add_upsample: - self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + self.upsamplers = nn.CellList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) else: self.upsamplers = None + self.gradient_checkpointing = False self.resolution_idx = resolution_idx - self._gradient_checkpointing = False - - @property - def gradient_checkpointing(self): - return self._gradient_checkpointing - - @gradient_checkpointing.setter - def gradient_checkpointing(self, value): - self._gradient_checkpointing = value - for resnet in self.resnets: - resnet._recompute(value) def construct( self, @@ -1247,107 +3082,97 @@ def construct( res_hidden_states_tuple: Tuple[ms.Tensor, ...], temb: Optional[ms.Tensor] = None, encoder_hidden_states: Optional[ms.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, upsample_size: Optional[int] = None, attention_mask: Optional[ms.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[ms.Tensor] = None, ) -> ms.Tensor: - is_freeu_enabled = ( - getattr(self, "s1", None) - and getattr(self, "s2", None) - and getattr(self, "b1", None) - and getattr(self, "b2", None) - ) + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. 
+ # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask for resnet, attn in zip(self.resnets, self.attentions): + # resnet # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - # FreeU: Only operate on the first two stages - if is_freeu_enabled: - raise NotImplementedError("apply_freeu is not implemented") - hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) hidden_states = resnet(hidden_states, temb) + hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - return_dict=False, - )[0] + attention_mask=mask, + **cross_attention_kwargs, + ) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states, temb) return hidden_states -class UpBlock2D(nn.Cell): +class KUpBlock2D(nn.Cell): def __init__( self, in_channels: int, - prev_output_channel: int, out_channels: int, temb_channels: int, - resolution_idx: Optional[int] = None, + resolution_idx: int, dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, + num_layers: int = 5, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: Optional[int] = 32, add_upsample: bool = True, ): super().__init__() resnets = [] - - self.has_cross_attention = False + k_in_channels = 2 * out_channels + k_out_channels = in_channels + num_layers = num_layers - 1 for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size resnets.append( - ResnetBlock2D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=k_out_channels if (i == num_layers - 1) else out_channels, temb_channels=temb_channels, eps=resnet_eps, - groups=resnet_groups, + groups=groups, + groups_out=groups_out, dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, ) ) self.resnets = nn.CellList(resnets) if add_upsample: - self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + self.upsamplers = nn.CellList([KUpsample2D()]) else: self.upsamplers = None + self.gradient_checkpointing = False self.resolution_idx = resolution_idx - self._gradient_checkpointing = False - - @property - def gradient_checkpointing(self): - return self._gradient_checkpointing - - @gradient_checkpointing.setter - def gradient_checkpointing(self, value): - self._gradient_checkpointing = value - for resnet in self.resnets: - resnet._recompute(value) + self.has_cross_attention = False def construct( self, @@ -1356,91 +3181,257 @@ def construct( temb: Optional[ms.Tensor] = None, upsample_size: 
Optional[int] = None, ) -> ms.Tensor: - is_freeu_enabled = ( - getattr(self, "s1", None) - and getattr(self, "s2", None) - and getattr(self, "b1", None) - and getattr(self, "b2", None) - ) + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = ops.cat([hidden_states, res_hidden_states_tuple], axis=1) for resnet in self.resnets: - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - - # FreeU: Only operate on the first two stages - if is_freeu_enabled: - raise NotImplementedError("apply_freeu is not implemented") - - hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) - hidden_states = resnet(hidden_states, temb) if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = upsampler(hidden_states) return hidden_states -class UpDecoderBlock2D(nn.Cell): +class KCrossAttnUpBlock2D(nn.Cell): def __init__( self, in_channels: int, out_channels: int, - resolution_idx: Optional[int] = None, + temb_channels: int, + resolution_idx: int, dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + attention_head_dim: int = 1, # attention dim_head + cross_attention_dim: int = 768, add_upsample: bool = True, - temb_channels: Optional[int] = None, + upcast_attention: bool = False, ): super().__init__() resnets = [] + attentions = [] - self.has_cross_attention = False + is_first_block = in_channels == out_channels == temb_channels + is_middle_block = in_channels != out_channels + add_self_attention = True if is_first_block else False + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + # in_channels, and out_channels for the block (k-unet) + k_in_channels = out_channels if is_first_block else 2 * out_channels + k_out_channels = in_channels + + num_layers = num_layers - 1 for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size - if resnet_time_scale_shift == "spatial": - raise NotImplementedError("ResnetBlockCondNorm2D is not implemented") + if is_middle_block and (i == num_layers - 1): + conv_2d_out_channels = k_out_channels else: - resnets.append( - ResnetBlock2D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - ) + conv_2d_out_channels = None + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + conv_2d_out_channels=conv_2d_out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + k_out_channels if (i == num_layers - 1) else out_channels, + 
k_out_channels // attention_head_dim + if (i == num_layers - 1) + else out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + upcast_attention=upcast_attention, ) + ) self.resnets = nn.CellList(resnets) + self.attentions = nn.CellList(attentions) if add_upsample: - self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + self.upsamplers = nn.CellList([KUpsample2D()]) else: self.upsamplers = None + self.gradient_checkpointing = False self.resolution_idx = resolution_idx - def construct(self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None) -> ms.Tensor: - for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=temb) + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[ms.Tensor] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = ops.cat([hidden_states, res_hidden_states_tuple], axis=1) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) return hidden_states + + +# can potentially later be renamed to `No-feed-forward` attention +class KAttentionBlock(nn.Cell): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + attention_bias (`bool`, *optional*, defaults to `False`): + Configure if the attention layers should contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to upcast the attention computation to `float32`. + temb_channels (`int`, *optional*, defaults to 768): + The number of channels in the token embedding. + add_self_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to add self-attention to the block. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + group_size (`int`, *optional*, defaults to 32): + The number of groups to separate the channels into for group normalization. 
+ """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + upcast_attention: bool = False, + temb_channels: int = 768, # for ada_group_norm + add_self_attention: bool = False, + cross_attention_norm: Optional[str] = None, + group_size: int = 32, + ): + super().__init__() + self.add_self_attention = add_self_attention + self.has_cross_attention = False + + # 1. Self-Attn + if add_self_attention: + self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + cross_attention_norm=None, + ) + + # 2. Cross-Attn + self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_attention_norm=cross_attention_norm, + ) + + def _to_3d(self, hidden_states: ms.Tensor, height: int, weight: int) -> ms.Tensor: + return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) + + def _to_4d(self, hidden_states: ms.Tensor, height: int, weight: int) -> ms.Tensor: + return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) + + def construct( + self, + hidden_states: ms.Tensor, + encoder_hidden_states: Optional[ms.Tensor] = None, + # TODO: mark emb as non-optional (self.norm2 requires it). + # requires assessing impact of change to positional param interface. + emb: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + # 1. Self-Attention + if self.add_self_attention: + norm_hidden_states = self.norm1(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + # 2. 
Cross-Attention/None + norm_hidden_states = self.norm2(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + return hidden_states diff --git a/mindone/diffusers/models/unets/unet_2d_condition.py b/mindone/diffusers/models/unets/unet_2d_condition.py index fa92b30a66..674e1fb00b 100644 --- a/mindone/diffusers/models/unets/unet_2d_condition.py +++ b/mindone/diffusers/models/unets/unet_2d_condition.py @@ -620,7 +620,7 @@ def _set_class_embedding( elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + self.class_embedding = nn.Identity() elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( @@ -1036,7 +1036,7 @@ def construct( copied_cross_attention_kwargs = {} for k, v in cross_attention_kwargs.items(): if k == "gligen": - copied_cross_attention_kwargs[k] = {"obj": self.position_net(**v)} + copied_cross_attention_kwargs[k] = {"objs": self.position_net(**v)} else: copied_cross_attention_kwargs[k] = v cross_attention_kwargs = copied_cross_attention_kwargs @@ -1059,6 +1059,9 @@ def construct( is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets is_adapter = down_intrablock_additional_residuals is not None + # using variable `adapter_index` to get item in `down_intrablock_additional_residuals` for avoiding + # pop operations in construct(), which are not fully supported in GRAPH_MODE + adapter_index = 0 # maintain backward compatibility for legacy usage, where # T2I-Adapter and ControlNet both use down_block_additional_residuals arg # but can only use one or the other @@ -1079,8 +1082,9 @@ def construct( if downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} - if is_adapter and len(down_intrablock_additional_residuals) > 0: - additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + if is_adapter and len(down_intrablock_additional_residuals) > adapter_index: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals[adapter_index] + adapter_index += 1 sample, res_samples = downsample_block( hidden_states=sample, @@ -1093,8 +1097,9 @@ def construct( ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) - if is_adapter and len(down_intrablock_additional_residuals) > 0: - sample += down_intrablock_additional_residuals.pop(0) + if is_adapter and len(down_intrablock_additional_residuals) > adapter_index: + sample += down_intrablock_additional_residuals[adapter_index] + adapter_index += 1 down_block_res_samples += res_samples @@ -1126,10 +1131,11 @@ def construct( # To support T2I-Adapter-XL if ( is_adapter - and len(down_intrablock_additional_residuals) > 0 + and len(down_intrablock_additional_residuals) > adapter_index and sample.shape == 
down_intrablock_additional_residuals[0].shape ): - sample += down_intrablock_additional_residuals.pop(0) + sample += down_intrablock_additional_residuals[adapter_index] + adapter_index += 1 if is_controlnet: sample = sample + mid_block_additional_residual diff --git a/mindone/diffusers/models/unets/unet_3d_blocks.py b/mindone/diffusers/models/unets/unet_3d_blocks.py new file mode 100644 index 0000000000..b2d63c4c38 --- /dev/null +++ b/mindone/diffusers/models/unets/unet_3d_blocks.py @@ -0,0 +1,2093 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple, Union + +import mindspore as ms +from mindspore import nn, ops + +from ...utils import logging +from ..attention import Attention +from ..resnet import Downsample2D, ResnetBlock2D, SpatioTemporalResBlock, TemporalConvLayer, Upsample2D +from ..transformers.dual_transformer_2d import DualTransformer2DModel +from ..transformers.transformer_2d import Transformer2DModel +from ..transformers.transformer_temporal import TransformerSpatioTemporalModel, TransformerTemporalModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + num_attention_heads: int, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + transformer_layers_per_block: int = 1, +) -> Union[ + "DownBlock3D", + "CrossAttnDownBlock3D", + "DownBlockMotion", + "CrossAttnDownBlockMotion", + "DownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", +]: + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, 
+ use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + if down_block_type == "DownBlockMotion": + return DownBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_max_seq_length=temporal_max_seq_length, + ) + elif down_block_type == "CrossAttnDownBlockMotion": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion") + return CrossAttnDownBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_max_seq_length=temporal_max_seq_length, + ) + elif down_block_type == "DownBlockSpatioTemporal": + # added for SDV + return DownBlockSpatioTemporal( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "CrossAttnDownBlockSpatioTemporal": + # added for SDV + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal") + return CrossAttnDownBlockSpatioTemporal( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + add_downsample=add_downsample, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + ) + + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + num_attention_heads: int, + resolution_idx: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + temporal_num_attention_heads: int = 8, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + transformer_layers_per_block: int = 1, + dropout: float = 0.0, +) -> Union[ + "UpBlock3D", + "CrossAttnUpBlock3D", + "UpBlockMotion", + "CrossAttnUpBlockMotion", + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", +]: + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + 
prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + ) + if up_block_type == "UpBlockMotion": + return UpBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_max_seq_length=temporal_max_seq_length, + ) + elif up_block_type == "CrossAttnUpBlockMotion": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion") + return CrossAttnUpBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_max_seq_length=temporal_max_seq_length, + ) + elif up_block_type == "UpBlockSpatioTemporal": + # added for SDV + return UpBlockSpatioTemporal( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + add_upsample=add_upsample, + ) + elif up_block_type == "CrossAttnUpBlockSpatioTemporal": + # added for SDV + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal") + return CrossAttnUpBlockSpatioTemporal( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + add_upsample=add_upsample, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + resolution_idx=resolution_idx, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + 
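+
+# Illustrative sketch (hypothetical channel sizes, not a canonical configuration):
+# a UNet variant assembles its encoder/decoder by calling these factories once per
+# resolution level, dispatching on the block-type string, e.g.
+#
+#     block = get_down_block(
+#         "CrossAttnDownBlock3D",
+#         num_layers=2,
+#         in_channels=320,
+#         out_channels=640,
+#         temb_channels=1280,
+#         add_downsample=True,
+#         resnet_eps=1e-5,
+#         resnet_act_fn="silu",
+#         num_attention_heads=8,
+#         resnet_groups=32,
+#         cross_attention_dim=768,
+#         downsample_padding=1,
+#     )
+#
+# Cross-attention block types require `cross_attention_dim`, and unknown block
+# names fall through to the trailing `raise ValueError`, mirroring the checks above.
+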
+class UNetMidBlock3DCrossAttn(nn.Cell): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ] + attentions = [] + temp_attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + in_channels // num_attention_heads, + num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + in_channels // num_attention_heads, + num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.CellList(resnets) + self.temp_convs = nn.CellList(temp_convs) + self.attentions = nn.CellList(attentions) + self.temp_attentions = nn.CellList(temp_attentions) + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> ms.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + return hidden_states + + +class 
CrossAttnDownBlock3D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + attentions.append( + Transformer2DModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.CellList(resnets) + self.temp_convs = nn.CellList(temp_convs) + self.attentions = nn.CellList(attentions) + self.temp_attentions = nn.CellList(temp_attentions) + + if add_downsample: + self.downsamplers = nn.CellList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Dict[str, Any] = None, + ) -> Union[ms.Tensor, Tuple[ms.Tensor, ...]]: + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + 
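+            # the downsampled feature map is appended as well, providing the extra
+            # skip connection consumed by the corresponding up block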
output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.CellList(resnets) + self.temp_convs = nn.CellList(temp_convs) + + if add_downsample: + self.downsamplers = nn.CellList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + self.has_cross_attention = False + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + num_frames: int = 1, + ) -> Union[ms.Tensor, Tuple[ms.Tensor, ...]]: + output_states = () + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resolution_idx: Optional[int] = None, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + 
TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + attentions.append( + Transformer2DModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.CellList(resnets) + self.temp_convs = nn.CellList(temp_convs) + self.attentions = nn.CellList(attentions) + self.temp_attentions = nn.CellList(temp_attentions) + + if add_upsample: + self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[ms.Tensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Dict[str, Any] = None, + ) -> ms.Tensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + # TODO(Patrick, William) - attention mask is not used + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + raise NotImplementedError("apply_freeu is not implemented") + + hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Cell): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + resolution_idx: Optional[int] = None, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + 
out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.CellList(resnets) + self.temp_convs = nn.CellList(temp_convs) + + if add_upsample: + self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + self.has_cross_attention = False + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + upsample_size: Optional[int] = None, + num_frames: int = 1, + ) -> ms.Tensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + raise NotImplementedError("apply_freeu is not implemented") + + hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class DownBlockMotion(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + temporal_num_attention_heads: int = 1, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + TransformerTemporalModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.resnets = nn.CellList(resnets) + self.motion_modules = nn.CellList(motion_modules) + + if add_downsample: + self.downsamplers = nn.CellList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, 
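# The "motion module" in these Motion blocks is a TransformerTemporalModel with
# sinusoidal positional embeddings over at most `temporal_max_seq_length` frames:
# it attends along the frame axis at each spatial location, while the
# ResnetBlock2D/Downsample2D path handles the spatial layout.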
+ name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + self.has_cross_attention = False + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + num_frames: int = 1, + ) -> Union[ms.Tensor, Tuple[ms.Tensor, ...]]: + output_states = () + + blocks = zip(self.resnets, self.motion_modules) + for resnet, motion_module in blocks: + hidden_states = resnet(hidden_states, temb) + hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockMotion(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + + motion_modules.append( + TransformerTemporalModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + self.motion_modules = nn.CellList(motion_modules) + + if add_downsample: + 
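# Rough usage sketch for DownBlockMotion defined above (channel sizes and shapes are
# illustrative assumptions, not part of this patch): frames are flattened into the
# batch dimension and recovered inside the block via `num_frames`.
from mindspore import ops
block = DownBlockMotion(in_channels=320, out_channels=320, temb_channels=1280)
sample = ops.randn(2 * 16, 320, 64, 64)  # batch=2, num_frames=16 folded into the batch
temb = ops.randn(2 * 16, 1280)
hidden, res_states = block(sample, temb, num_frames=16)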
self.downsamplers = nn.CellList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + num_frames: int = 1, + encoder_attention_mask: Optional[ms.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + additional_residuals: Optional[ms.Tensor] = None, + ): + output_states = () + + blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) + for i, (resnet, attn, motion_module) in enumerate(blocks): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlockMotion(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + 
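# use_linear_projection switches the spatial transformer's in/out projections from
# 1x1 convolutions to dense layers (the SD 2.x convention); with
# dual_cross_attention=True the block instead builds a DualTransformer2DModel that
# runs two parallel cross-attention streams in place of a single Transformer2DModel.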
use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + motion_modules.append( + TransformerTemporalModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + self.motion_modules = nn.CellList(motion_modules) + + if add_upsample: + self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[ms.Tensor] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + num_frames: int = 1, + ) -> ms.Tensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.attentions, self.motion_modules) + for resnet, attn, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + raise NotImplementedError("apply_freeu is not implemented") + + hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlockMotion(nn.Cell): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temporal_norm_num_groups: int = 32, + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + res_skip_channels = in_channels 
if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + motion_modules.append( + TransformerTemporalModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + norm_num_groups=temporal_norm_num_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.resnets = nn.CellList(resnets) + self.motion_modules = nn.CellList(motion_modules) + + if add_upsample: + self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + self.has_cross_attention = False + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + upsample_size=None, + num_frames: int = 1, + ) -> ms.Tensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.motion_modules) + + for resnet, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + raise NotImplementedError("apply_freeu is not implemented") + + hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = motion_module(hidden_states, num_frames=num_frames)[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UNetMidBlockCrossAttnMotion(nn.Cell): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_num_attention_heads: int = 1, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + ): + super().__init__() + + self.has_cross_attention = True + self.has_motion_modules = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + 
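# Note: the `is_freeu_enabled` checks in the up blocks only detect the s1/s2/b1/b2
# attributes that `enable_freeu` would set and then raise NotImplementedError, so
# FreeU is not yet usable with these MindSpore ports.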
dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + motion_modules = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + TransformerTemporalModel( + num_attention_heads=temporal_num_attention_heads, + attention_head_dim=in_channels // temporal_num_attention_heads, + in_channels=in_channels, + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + activation_fn="geglu", + ) + ) + + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + self.motion_modules = nn.CellList(motion_modules) + + self.gradient_checkpointing = False + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[ms.Tensor] = None, + num_frames: int = 1, + ) -> ms.Tensor: + hidden_states = self.resnets[0](hidden_states, temb) + + blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) + for attn, resnet, motion_module in blocks: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class MidBlockTemporalDecoder(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + attention_head_dim: int = 512, + num_layers: int = 1, + upcast_attention: bool = False, + ): + super().__init__() + + resnets = [] + attentions = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=1e-6, + temporal_eps=1e-5, + merge_factor=0.0, + merge_strategy="learned", + switch_spatial_to_temporal_mix=True, + ) + ) + + attentions.append( + Attention( + query_dim=in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + eps=1e-6, + upcast_attention=upcast_attention, + 
norm_num_groups=32, + bias=True, + residual_connection=True, + ) + ) + + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + self.has_cross_attention = False + + def construct( + self, + hidden_states: ms.Tensor, + image_only_indicator: ms.Tensor, + ): + hidden_states = self.resnets[0]( + hidden_states, + image_only_indicator=image_only_indicator, + ) + for resnet, attn in zip(self.resnets[1:], self.attentions): + hidden_states = attn(hidden_states) + hidden_states = resnet( + hidden_states, + image_only_indicator=image_only_indicator, + ) + + return hidden_states + + +class UpBlockTemporalDecoder(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=1e-6, + temporal_eps=1e-5, + merge_factor=0.0, + merge_strategy="learned", + switch_spatial_to_temporal_mix=True, + ) + ) + self.resnets = nn.CellList(resnets) + + if add_upsample: + self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.has_cross_attention = False + + def construct( + self, + hidden_states: ms.Tensor, + image_only_indicator: ms.Tensor, + ) -> ms.Tensor: + for resnet in self.resnets: + hidden_states = resnet( + hidden_states, + image_only_indicator=image_only_indicator, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UNetMidBlockSpatioTemporal(nn.Cell): + def __init__( + self, + in_channels: int, + temb_channels: int, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + # there is always at least one resnet + resnets = [ + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ] + attentions = [] + + for i in range(num_layers): + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ) + + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + + self.gradient_checkpointing = False + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + image_only_indicator: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + hidden_states = self.resnets[0]( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + 
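# `image_only_indicator` is a (batch, num_frames) tensor of flags consumed by the
# spatio-temporal resnets when blending their spatial and temporal branches; passing
# an all-zeros tensor treats every entry as a video frame.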
image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + return hidden_states + + +class DownBlockSpatioTemporal(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + num_layers: int = 1, + add_downsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ) + + self.resnets = nn.CellList(resnets) + + if add_downsample: + self.downsamplers = nn.CellList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + self.has_cross_attention = False + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + image_only_indicator: Optional[ms.Tensor] = None, + ) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...]]: + output_states = () + for resnet in self.resnets: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockSpatioTemporal(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + add_downsample: bool = True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=1e-6, + ) + ) + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + + if add_downsample: + self.downsamplers = nn.CellList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=1, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def construct( + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + image_only_indicator: Optional[ms.Tensor] = None, + ) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...]]: + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + for resnet, attn in blocks: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + 
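# Illustrative sketch for DownBlockSpatioTemporal above (all sizes are assumptions,
# not taken from this patch):
import mindspore as ms
from mindspore import ops
block = DownBlockSpatioTemporal(in_channels=320, out_channels=320, temb_channels=1280)
x = ops.randn(2 * 14, 320, 40, 40)  # batch=2, num_frames=14, frames folded into batch
temb = ops.randn(2 * 14, 1280)
indicator = ops.zeros((2, 14), ms.float32)
hidden, res_states = block(x, temb, image_only_indicator=indicator)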
image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class UpBlockSpatioTemporal(nn.Cell): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + num_layers: int = 1, + resnet_eps: float = 1e-6, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + ) + ) + + self.resnets = nn.CellList(resnets) + + if add_upsample: + self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + self.has_cross_attention = False + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + image_only_indicator: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) + + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpBlockSpatioTemporal(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + ) + ) + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + + if add_upsample: + self.upsamplers = nn.CellList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = 
None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + image_only_indicator: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) + + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states diff --git a/mindone/diffusers/models/unets/unet_3d_condition.py b/mindone/diffusers/models/unets/unet_3d_condition.py new file mode 100644 index 0000000000..177f0136a3 --- /dev/null +++ b/mindone/diffusers/models/unets/unet_3d_condition.py @@ -0,0 +1,605 @@ +# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# Copyright 2024 The ModelScope Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import mindspore as ms +from mindspore import nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import BaseOutput, logging +from ..activations import get_activation +from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..normalization import GroupNorm +from ..transformers.transformer_temporal import TransformerTemporalModel +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + """ + The output of [`UNet3DConditionModel`]. + + Args: + sample (`ms.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: ms.Tensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. 
Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): The number of attention heads. + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + up_block_types: Tuple[str, ...] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + ), + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1024, + attention_head_dim: Union[int, Tuple[int]] = 64, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise NotImplementedError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in" + "https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. " + "Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. 
The incorrect naming was only discovered much later in + # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. " + f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. " + f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. " + f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + pad_mode="pad", + padding=conv_in_padding, + has_bias=True, + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=attention_head_dim, + in_channels=block_out_channels[0], + num_layers=1, + norm_num_groups=norm_num_groups, + ) + + # class embedding + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + self.down_blocks = [] + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=False, + ) + self.down_blocks.append(down_block) + self.down_blocks = nn.CellList(self.down_blocks) + + # mid + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + self.up_blocks = [] + layers_per_resnet_in_up_blocks = [] + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + + output_channel = 
reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=False, + resolution_idx=i, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + layers_per_resnet_in_up_blocks.append(len(up_block.resnets)) + self.up_blocks = nn.CellList(self.up_blocks) + self.layers_per_resnet_in_up_blocks = layers_per_resnet_in_up_blocks + + # out + if norm_num_groups is not None: + self.conv_norm_out = GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = get_activation("silu")() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], + out_channels, + kernel_size=conv_out_kernel, + pad_mode="pad", + padding=conv_out_padding, + has_bias=True, + ) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.name_cells().items(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.name_cells().items(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: nn.Cell, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.name_cells().items(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.name_cells().items(): + fn_recursive_attn_processor(name, module, processor) + + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.name_cells().items(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.name_cells().items(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.name_cells().items(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.name_cells().items(): + fn_recursive_feed_forward(module, None, 0) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def construct( + self, + sample: ms.Tensor, + timestep: Union[ms.Tensor, float, int], + encoder_hidden_states: ms.Tensor, + class_labels: Optional[ms.Tensor] = None, + timestep_cond: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[ms.Tensor]] = None, + mid_block_additional_residual: Optional[ms.Tensor] = None, + return_dict: bool = False, + ) -> Union[UNet3DConditionOutput, Tuple[ms.Tensor]]: + r""" + The [`UNet3DConditionModel`] forward method. 
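# Hypothetical sketch for the processor helpers above (the default config builds the
# full-size UNet, so this is illustrative only):
unet = UNet3DConditionModel()
print(len(unet.attn_processors))   # one entry per attention layer, keyed by weight name
unet.set_default_attn_processor()  # swap every processor back to plain AttnProcessor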
+ + Args: + sample (`ms.Tensor`): + The noisy input tensor with the following shape `(batch, num_channels, num_frames, height, width`. + timestep (`ms.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`ms.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`ms.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`ms.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`ms.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + down_block_additional_residuals: (`tuple` of `ms.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`ms.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + + Returns: + [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if sample.shape[-2] % default_overall_up_factor != 0 or sample.shape[-1] % default_overall_up_factor != 0: + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not ops.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + if isinstance(timestep, float): + dtype = ms.float64 + else: + dtype = ms.int64 + timesteps = ms.Tensor([timesteps], dtype=dtype) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None] + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.broadcast_to((sample.shape[0],)) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + + sample = self.transformer_in( + sample, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-self.layers_per_resnet_in_up_blocks[i] :] + down_block_res_samples = down_block_res_samples[: -self.layers_per_resnet_in_up_blocks[i]] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 6. 
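# Worked shape example for the whole construct above (all sizes are assumptions):
#
#     from mindspore import ops
#     unet = UNet3DConditionModel()
#     latents = ops.randn(1, 4, 16, 32, 32)   # (batch, channels, frames, height, width)
#     text_emb = ops.randn(1, 77, 1024)       # cross_attention_dim defaults to 1024
#     out = unet(latents, ms.Tensor(10, ms.int64), encoder_hidden_states=text_emb)[0]
#     # out.shape == (1, 4, 16, 32, 32)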
post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) diff --git a/mindone/diffusers/models/unets/unet_i2vgen_xl.py b/mindone/diffusers/models/unets/unet_i2vgen_xl.py new file mode 100644 index 0000000000..3b5c3272eb --- /dev/null +++ b/mindone/diffusers/models/unets/unet_i2vgen_xl.py @@ -0,0 +1,650 @@ +# Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple, Union + +import mindspore as ms +from mindspore import nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import logging +from ..activations import SiLU, get_activation +from ..attention import Attention, FeedForward +from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..normalization import GroupNorm, LayerNorm +from ..transformers.transformer_temporal import TransformerTemporalModel +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) +from .unet_3d_condition import UNet3DConditionOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class I2VGenXLTransformerTemporalEncoder(nn.Cell): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + activation_fn: str = "geglu", + upcast_attention: bool = False, + ff_inner_dim: Optional[int] = None, + dropout: int = 0.0, + ): + super().__init__() + self.norm1 = LayerNorm(dim, elementwise_affine=True, eps=1e-5) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=True, + ) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=False, + inner_dim=ff_inner_dim, + bias=True, + ) + + def construct( + self, + hidden_states: ms.Tensor, + ) -> ms.Tensor: + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + ff_output = self.ff(hidden_states) + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + I2VGenXL 
UNet. It is a conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep
+    and returns a sample-shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+            Height and width of input/output sample.
+        in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
+        out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
+            The tuple of downsample blocks to use.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
+            The tuple of upsample blocks to use.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+            The tuple of output channels for each block.
+        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+            If `None`, normalization and activation layers are skipped in post-processing.
+        cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
+        attention_head_dim (`int`, *optional*, defaults to 64): Attention head dim.
+        num_attention_heads (`int`, *optional*): The number of attention heads.
+    """
+
+    _supports_gradient_checkpointing = False
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: Optional[int] = None,
+        in_channels: int = 4,
+        out_channels: int = 4,
+        down_block_types: Tuple[str, ...] = (
+            "CrossAttnDownBlock3D",
+            "CrossAttnDownBlock3D",
+            "CrossAttnDownBlock3D",
+            "DownBlock3D",
+        ),
+        up_block_types: Tuple[str, ...] = (
+            "UpBlock3D",
+            "CrossAttnUpBlock3D",
+            "CrossAttnUpBlock3D",
+            "CrossAttnUpBlock3D",
+        ),
+        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+        layers_per_block: int = 2,
+        norm_num_groups: Optional[int] = 32,
+        cross_attention_dim: int = 1024,
+        attention_head_dim: Union[int, Tuple[int]] = 64,
+        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+    ):
+        super().__init__()
+
+        # When we first integrated the UNet into the library, we didn't have `attention_head_dim`. As a consequence
+        # of that, we used `num_attention_heads` for arguments that actually denote attention head dimension. This
+        # is why we ignore `num_attention_heads` and calculate it from `attention_head_dim` below.
+        # This is still an incorrect way of calculating `num_attention_heads` but we need to stick to it
+        # without running proper deprecation cycles for the {down,mid,up} blocks which are a
+        # part of the public API.
+        num_attention_heads = attention_head_dim
+
+        # Check inputs
+        if len(down_block_types) != len(up_block_types):
+            raise ValueError(
+                f"Must provide the same number of `down_block_types` as `up_block_types`. "
+                f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+            )
+
+        if len(block_out_channels) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `block_out_channels` as `down_block_types`. "
+                f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. " + f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + # input + self.conv_in = nn.Conv2d( + in_channels + in_channels, block_out_channels[0], kernel_size=3, pad_mode="pad", padding=1, has_bias=True + ) + + self.transformer_in = TransformerTemporalModel( + num_attention_heads=8, + attention_head_dim=num_attention_heads, + in_channels=block_out_channels[0], + num_layers=1, + norm_num_groups=norm_num_groups, + ) + + # image embedding + self.image_latents_proj_in = nn.SequentialCell( + nn.Conv2d(4, in_channels * 4, 3, pad_mode="pad", padding=1, has_bias=True), + SiLU(), + nn.Conv2d(in_channels * 4, in_channels * 4, 3, stride=1, pad_mode="pad", padding=1, has_bias=True), + SiLU(), + nn.Conv2d(in_channels * 4, in_channels, 3, stride=1, pad_mode="pad", padding=1, has_bias=True), + ) + self.image_latents_temporal_encoder = I2VGenXLTransformerTemporalEncoder( + dim=in_channels, + num_attention_heads=2, + ff_inner_dim=in_channels * 4, + attention_head_dim=in_channels, + activation_fn="gelu", + ) + self.image_latents_context_embedding = nn.SequentialCell( + nn.Conv2d(4, in_channels * 8, 3, pad_mode="pad", padding=1, has_bias=True), + SiLU(), + nn.AdaptiveAvgPool2d((32, 32)), + nn.Conv2d(in_channels * 8, in_channels * 16, 3, stride=2, pad_mode="pad", padding=1, has_bias=True), + SiLU(), + nn.Conv2d(in_channels * 16, cross_attention_dim, 3, stride=2, pad_mode="pad", padding=1, has_bias=True), + ) + + # other embeddings -- time, context, fps, etc. + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn="silu") + self.context_embedding = nn.SequentialCell( + nn.Dense(cross_attention_dim, time_embed_dim), + SiLU(), + nn.Dense(time_embed_dim, cross_attention_dim * in_channels), + ) + self.fps_embedding = nn.SequentialCell( + nn.Dense(timestep_input_dim, time_embed_dim), SiLU(), nn.Dense(time_embed_dim, time_embed_dim) + ) + + # blocks + self.down_blocks = [] + self.up_blocks = [] + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=1e-05, + resnet_act_fn="silu", + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + downsample_padding=1, + dual_cross_attention=False, + ) + self.down_blocks.append(down_block) + self.down_blocks = nn.CellList(self.down_blocks) + + # mid + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=1e-05, + resnet_act_fn="silu", + output_scale_factor=1, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, 
+ ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + layers_per_resnet_in_up_blocks = [] + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=1e-05, + resnet_act_fn="silu", + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=False, + resolution_idx=i, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + layers_per_resnet_in_up_blocks.append(len(up_block.resnets)) + self.up_blocks = nn.CellList(self.up_blocks) + self.layers_per_resnet_in_up_blocks = layers_per_resnet_in_up_blocks + + # out + self.conv_norm_out = GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-05) + self.conv_act = get_activation("silu")() + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=3, pad_mode="pad", padding=1, has_bias=True + ) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: # type: ignore + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str, AttentionProcessor]): # type: ignore + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.name_cells().items(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.name_cells().items(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): # type: ignore + r""" + Sets the attention processor to use to compute attention. + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. 
+ """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: nn.Cell, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.name_cells().items(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.name_cells().items(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.name_cells().items(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.name_cells().items(): + fn_recursive_feed_forward(module, chunk_size, dim) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.name_cells().items(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.name_cells().items(): + fn_recursive_feed_forward(module, None, 0) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. 
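Editor's note on the `enable_forward_chunking` helper defined above: chunking only changes *how* the feed-forward projection is evaluated, not its result, which is why it can be toggled freely to trade peak memory for extra launches. Below is a minimal standalone sketch of the idea; the sizes are illustrative and a plain `nn.Dense` stands in for the attention block's feed-forward, so nothing here is part of the diff itself.

```python
import mindspore as ms
from mindspore import nn, ops

ff = nn.Dense(64, 64)        # stand-in for an attention block's feed-forward layer
x = ops.randn(2, 4096, 64)   # (batch, sequence, features)

# Chunked evaluation: run the feed-forward over slices of the sequence axis (dim=1)
# and concatenate, so peak activation memory scales with the chunk size.
chunk_size = 512
out_chunked = ops.cat([ff(chunk) for chunk in x.split(chunk_size, axis=1)], axis=1)

# Unchunked evaluation gives the same values (up to floating-point noise),
# because the projection is applied position-wise along the chunked axis.
out_full = ff(x)
print(ops.isclose(out_chunked, out_full, atol=1e-5).all())
```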
+ """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel._set_gradient_checkpointing + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def construct( + self, + sample: ms.Tensor, + timestep: Union[ms.Tensor, float, int], + fps: ms.Tensor, + image_latents: ms.Tensor, + image_embeddings: Optional[ms.Tensor] = None, + encoder_hidden_states: Optional[ms.Tensor] = None, + timestep_cond: Optional[ms.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = False, + ) -> Union[UNet3DConditionOutput, Tuple[ms.Tensor]]: + r""" + The [`I2VGenXLUNet`] forward method. + + Args: + sample (`ms.Tensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`. + timestep (`ms.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + fps (`ms.Tensor`): Frames per second for the video being generated. Used as a "micro-condition". + image_latents (`ms.Tensor`): Image encodings from the VAE. + image_embeddings (`ms.Tensor`): Projection embeddings of the conditioning image computed with a vision encoder. + encoder_hidden_states (`ms.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + batch_size, channels, num_frames, height, width = sample.shape + + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if sample.shape[-1] % default_overall_up_factor != 0 or sample.shape[-2] % default_overall_up_factor != 0: + forward_upsample_size = True + + # 1. time + timesteps = timestep + if not ops.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass `timesteps` as tensors if you can + if isinstance(timesteps, float): + dtype = ms.float64 + else: + dtype = ms.int64 + timesteps = ms.Tensor([timesteps], dtype=dtype) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None] + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.broadcast_to((sample.shape[0],)) + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + t_emb = self.time_embedding(t_emb, timestep_cond) + + # 2. FPS + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + fps = fps.broadcast_to((fps.shape[0],)) + fps_emb = self.fps_embedding(self.time_proj(fps).to(dtype=self.dtype)) + + # 3. time + FPS embeddings. + emb = t_emb + fps_emb + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + + # 4. context embeddings. + # The context embeddings consist of both text embeddings from the input prompt + # AND the image embeddings from the input image. For images, both VAE encodings + # and the CLIP image embeddings are incorporated. + # So the final `context_embeddings` becomes the query for cross-attention. + context_emb = sample.new_zeros((batch_size, 0, self.config["cross_attention_dim"])) + context_emb = context_emb.to(sample.dtype) + context_emb = ops.cat([context_emb, encoder_hidden_states], axis=1) + + image_latents_for_context_embds = image_latents[:, :, :1, :] + image_latents_context_embs = image_latents_for_context_embds.permute(0, 2, 1, 3, 4).reshape( + image_latents_for_context_embds.shape[0] * image_latents_for_context_embds.shape[2], + image_latents_for_context_embds.shape[1], + image_latents_for_context_embds.shape[3], + image_latents_for_context_embds.shape[4], + ) + image_latents_context_embs = self.image_latents_context_embedding(image_latents_context_embs) + + _batch_size, _channels, _height, _width = image_latents_context_embs.shape + image_latents_context_embs = image_latents_context_embs.permute(0, 2, 3, 1).reshape( + _batch_size, _height * _width, _channels + ) + context_emb = ops.cat([context_emb, image_latents_context_embs], axis=1) + + image_emb = self.context_embedding(image_embeddings) + image_emb = image_emb.view(-1, self.config["in_channels"], self.config["cross_attention_dim"]) + context_emb = ops.cat([context_emb, image_emb], axis=1) + context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0) + + image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape( + image_latents.shape[0] * image_latents.shape[2], + image_latents.shape[1], + image_latents.shape[3], + image_latents.shape[4], + ) + image_latents = self.image_latents_proj_in(image_latents) + image_latents = ( + image_latents[None, :] + .reshape(batch_size, num_frames, channels, height, width) + .permute(0, 3, 4, 1, 2) + .reshape(batch_size * height * width, num_frames, channels) + ) + image_latents = self.image_latents_temporal_encoder(image_latents) + image_latents = image_latents.reshape(batch_size, height, width, num_frames, channels).permute(0, 4, 3, 1, 2) + + # 5. 
pre-process + sample = ops.cat([sample, image_latents], axis=1) + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + sample = self.transformer_in( + sample, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # 6. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=context_emb, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + # 7. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=context_emb, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + # 8. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-self.layers_per_resnet_in_up_blocks[i] :] + down_block_res_samples = down_block_res_samples[: -self.layers_per_resnet_in_up_blocks[i]] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=context_emb, + upsample_size=upsample_size, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 9. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) diff --git a/mindone/diffusers/models/unets/unet_kandinsky3.py b/mindone/diffusers/models/unets/unet_kandinsky3.py new file mode 100644 index 0000000000..5848629ad5 --- /dev/null +++ b/mindone/diffusers/models/unets/unet_kandinsky3.py @@ -0,0 +1,539 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
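Editor's note: both video UNets above (`UNet3DConditionModel.construct` and `I2VGenXLUNet.construct`) reuse 2D blocks by folding the frame axis into the batch axis and repeating the per-batch conditioning once per frame. Below is a minimal shape sketch of that bookkeeping; the sizes are illustrative and the tensors only mirror the code above, they are not part of the diff.

```python
from mindspore import ops

# Video latents follow the layout used above: (batch, channels, frames, height, width).
batch, channels, frames, height, width = 2, 4, 8, 32, 32
sample = ops.randn(batch, channels, frames, height, width)

# Fold the frame axis into the batch axis so every 2D block sees (batch * frames, channels, h, w).
folded = sample.permute(0, 2, 1, 3, 4).reshape((batch * frames, channels, height, width))
print(folded.shape)  # (16, 4, 32, 32)

# Per-batch conditioning (time embedding, text/image context) is repeated once per frame
# so it lines up with the folded batch axis.
emb = ops.randn(batch, 1280)
emb = emb.repeat_interleave(repeats=frames, dim=0)
print(emb.shape)  # (16, 1280)

# After the UNet blocks, the output is unfolded back to (batch, channels, frames, h, w).
unfolded = folded[None, :].reshape((batch, frames, channels, height, width)).permute(0, 2, 1, 3, 4)
print(unfolded.shape)  # (2, 4, 8, 32, 32)
```

The final reshape in each `construct` above is exactly this unfolding step, which restores the `(batch, channels, frames, height, width)` layout before the sample is returned.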
+ +from dataclasses import dataclass +from typing import Dict, Tuple, Union + +import mindspore as ms +from mindspore import nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput, logging +from ..activations import SiLU +from ..attention_processor import Attention, AttentionProcessor, AttnProcessor +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..normalization import GroupNorm, LayerNorm + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class Kandinsky3UNetOutput(BaseOutput): + sample: ms.Tensor = None + + +class Kandinsky3EncoderProj(nn.Cell): + def __init__(self, encoder_hid_dim, cross_attention_dim): + super().__init__() + self.projection_linear = nn.Dense(encoder_hid_dim, cross_attention_dim, has_bias=False) + self.projection_norm = LayerNorm(cross_attention_dim) + + def construct(self, x): + x = self.projection_linear(x) + x = self.projection_norm(x) + return x + + +class Kandinsky3UNet(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + in_channels: int = 4, + time_embedding_dim: int = 1536, + groups: int = 32, + attention_head_dim: int = 64, + layers_per_block: Union[int, Tuple[int]] = 3, + block_out_channels: Tuple[int] = (384, 768, 1536, 3072), + cross_attention_dim: Union[int, Tuple[int]] = 4096, + encoder_hid_dim: int = 4096, + ): + super().__init__() + + # TOOD(Yiyi): Give better name and put into config for the following 4 parameters + expansion_ratio = 4 + compression_ratio = 2 + add_cross_attention = (False, True, True, True) + add_self_attention = (False, True, True, True) + + out_channels = in_channels + init_channels = block_out_channels[0] // 2 + self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1) + + self.time_embedding = TimestepEmbedding( + init_channels, + time_embedding_dim, + ) + + self.add_time_condition = Kandinsky3AttentionPooling( + time_embedding_dim, cross_attention_dim, attention_head_dim + ) + + self.conv_in = nn.Conv2d(in_channels, init_channels, kernel_size=3, pad_mode="pad", padding=1, has_bias=True) + + self.encoder_hid_proj = Kandinsky3EncoderProj(encoder_hid_dim, cross_attention_dim) + + hidden_dims = [init_channels] + list(block_out_channels) + in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:])) + text_dims = [cross_attention_dim if is_exist else None for is_exist in add_cross_attention] + num_blocks = len(block_out_channels) * [layers_per_block] + layer_params = [num_blocks, text_dims, add_self_attention] + rev_layer_params = map(reversed, layer_params) + + cat_dims = [] + self.num_levels = len(in_out_dims) + self.down_blocks = [] + for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate( + zip(in_out_dims, *layer_params) + ): + down_sample = level != (self.num_levels - 1) + cat_dims.append(out_dim if level != (self.num_levels - 1) else 0) + self.down_blocks.append( + Kandinsky3DownSampleBlock( + in_dim, + out_dim, + time_embedding_dim, + text_dim, + res_block_num, + groups, + attention_head_dim, + expansion_ratio, + compression_ratio, + down_sample, + self_attention, + ) + ) + self.down_blocks = nn.CellList(self.down_blocks) + + self.up_blocks = [] + for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate( + zip(reversed(in_out_dims), *rev_layer_params) + ): + up_sample = level != 0 + self.up_blocks.append( + Kandinsky3UpSampleBlock( + in_dim, + cat_dims.pop(), + out_dim, + 
time_embedding_dim, + text_dim, + res_block_num, + groups, + attention_head_dim, + expansion_ratio, + compression_ratio, + up_sample, + self_attention, + ) + ) + self.up_blocks = nn.CellList(self.up_blocks) + + self.conv_norm_out = GroupNorm(groups, init_channels) + self.conv_act_out = SiLU() + self.conv_out = nn.Conv2d(init_channels, out_channels, kernel_size=3, pad_mode="pad", padding=1, has_bias=True) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.name_cells().items(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.name_cells().items(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: nn.Cell, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.name_cells().items(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.name_cells().items(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. 
+ """ + self.set_attn_processor(AttnProcessor()) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def construct(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=False): + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if not ops.is_tensor(timestep): + dtype = ms.float32 if isinstance(timestep, float) else ms.int32 + timestep = ms.Tensor([timestep], dtype=dtype) + elif len(timestep.shape) == 0: + timestep = timestep[None] + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = timestep.broadcast_to((sample.shape[0],)) + time_embed_input = self.time_proj(timestep).to(sample.dtype) + time_embed = self.time_embedding(time_embed_input) + + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + + if encoder_hidden_states is not None: + time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask) + + hidden_states = [] + sample = self.conv_in(sample) + for level, down_sample in enumerate(self.down_blocks): + sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask) + if level != self.num_levels - 1: + hidden_states.append(sample) + + for level, up_sample in enumerate(self.up_blocks): + if level != 0: + sample = ops.cat([sample, hidden_states[-level]], axis=1) + sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask) + + sample = self.conv_norm_out(sample) + sample = self.conv_act_out(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + return Kandinsky3UNetOutput(sample=sample) + + +class Kandinsky3UpSampleBlock(nn.Cell): + def __init__( + self, + in_channels, + cat_dim, + out_channels, + time_embed_dim, + context_dim=None, + num_blocks=3, + groups=32, + head_dim=64, + expansion_ratio=4, + compression_ratio=2, + up_sample=True, + self_attention=True, + ): + super().__init__() + up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1) + hidden_channels = ( + [(in_channels + cat_dim, in_channels)] + + [(in_channels, in_channels)] * (num_blocks - 2) + + [(in_channels, out_channels)] + ) + attentions = [] + resnets_in = [] + resnets_out = [] + + self.self_attention = self_attention + self.context_dim = context_dim + + if self_attention: + attentions.append( + Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio) + ) + else: + attentions.append(nn.Identity()) + + for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions): + resnets_in.append( + Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution) + ) + + if context_dim is not None: + attentions.append( + Kandinsky3AttentionBlock(in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio) + ) + else: + attentions.append(nn.Identity()) + + resnets_out.append( + Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio) + ) + + self.attentions = nn.CellList(attentions) + self.resnets_in = nn.CellList(resnets_in) + self.resnets_out = nn.CellList(resnets_out) + + def construct(self, x, time_embed, context=None, context_mask=None, image_mask=None): + for attention, resnet_in, 
resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out): + x = resnet_in(x, time_embed) + if self.context_dim is not None: + x = attention(x, time_embed, context, context_mask, image_mask) + x = resnet_out(x, time_embed) + + if self.self_attention: + x = self.attentions[0](x, time_embed, image_mask=image_mask) + return x + + +class Kandinsky3DownSampleBlock(nn.Cell): + def __init__( + self, + in_channels, + out_channels, + time_embed_dim, + context_dim=None, + num_blocks=3, + groups=32, + head_dim=64, + expansion_ratio=4, + compression_ratio=2, + down_sample=True, + self_attention=True, + ): + super().__init__() + attentions = [] + resnets_in = [] + resnets_out = [] + + self.self_attention = self_attention + self.context_dim = context_dim + + if self_attention: + attentions.append( + Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio) + ) + else: + attentions.append(nn.Identity()) + + up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, False if down_sample else None, None]] + hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1) + for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions): + resnets_in.append(Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)) + + if context_dim is not None: + attentions.append( + Kandinsky3AttentionBlock( + out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio + ) + ) + else: + attentions.append(nn.Identity()) + + resnets_out.append( + Kandinsky3ResNetBlock( + out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution + ) + ) + + self.attentions = nn.CellList(attentions) + self.resnets_in = nn.CellList(resnets_in) + self.resnets_out = nn.CellList(resnets_out) + + def construct(self, x, time_embed, context=None, context_mask=None, image_mask=None): + if self.self_attention: + x = self.attentions[0](x, time_embed, image_mask=image_mask) + + for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out): + x = resnet_in(x, time_embed) + if self.context_dim is not None: + x = attention(x, time_embed, context, context_mask, image_mask) + x = resnet_out(x, time_embed) + return x + + +class Kandinsky3ConditionalGroupNorm(nn.Cell): + def __init__(self, groups, normalized_shape, context_dim): + super().__init__() + self.norm = GroupNorm(groups, normalized_shape, affine=False) + self.context_mlp = nn.SequentialCell( + SiLU(), nn.Dense(context_dim, 2 * normalized_shape, weight_init="zeros", bias_init="zeros") + ) + + def construct(self, x, context): + context = self.context_mlp(context) + + for _ in range(len(x.shape[2:])): + context = context.unsqueeze(-1) + + scale, shift = context.chunk(2, axis=1) + x = self.norm(x) * (scale + 1.0) + shift + return x + + +class Kandinsky3Block(nn.Cell): + def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None): + super().__init__() + self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim) + self.activation = SiLU() + if up_resolution is not None and up_resolution: + self.up_sample = nn.Conv2dTranspose( + in_channels, in_channels, kernel_size=2, stride=2, pad_mode="pad", has_bias=True + ) + else: + self.up_sample = nn.Identity() + + padding = int(kernel_size > 1) + self.projection = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, pad_mode="pad", 
padding=padding, has_bias=True + ) + + if up_resolution is not None and not up_resolution: + self.down_sample = nn.Conv2d( + out_channels, out_channels, kernel_size=2, stride=2, pad_mode="pad", has_bias=True + ) + else: + self.down_sample = nn.Identity() + + def construct(self, x, time_embed): + x = self.group_norm(x, time_embed) + x = self.activation(x) + x = self.up_sample(x) + x = self.projection(x) + x = self.down_sample(x) + return x + + +class Kandinsky3ResNetBlock(nn.Cell): + def __init__( + self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2, up_resolutions=4 * [None] + ): + super().__init__() + kernel_sizes = [1, 3, 3, 1] + hidden_channel = max(in_channels, out_channels) // compression_ratio + hidden_channels = ( + [(in_channels, hidden_channel)] + [(hidden_channel, hidden_channel)] * 2 + [(hidden_channel, out_channels)] + ) + self.resnet_blocks = nn.CellList( + [ + Kandinsky3Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution) + for (in_channel, out_channel), kernel_size, up_resolution in zip( + hidden_channels, kernel_sizes, up_resolutions + ) + ] + ) + self.shortcut_up_sample = ( + nn.Conv2dTranspose(in_channels, in_channels, kernel_size=2, stride=2, pad_mode="pad", has_bias=True) + if True in up_resolutions + else nn.Identity() + ) + self.shortcut_projection = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, has_bias=True) + if in_channels != out_channels + else nn.Identity() + ) + self.shortcut_down_sample = ( + nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2, pad_mode="pad", has_bias=True) + if False in up_resolutions + else nn.Identity() + ) + + def construct(self, x, time_embed): + out = x + for resnet_block in self.resnet_blocks: + out = resnet_block(out, time_embed) + + x = self.shortcut_up_sample(x) + x = self.shortcut_projection(x) + x = self.shortcut_down_sample(x) + x = x + out + return x + + +class Kandinsky3AttentionPooling(nn.Cell): + def __init__(self, num_channels, context_dim, head_dim=64): + super().__init__() + self.attention = Attention( + context_dim, + context_dim, + dim_head=head_dim, + out_dim=num_channels, + out_bias=False, + ) + + def construct(self, x, context, context_mask=None): + if context_mask is not None: + context_mask = context_mask.to(dtype=context.dtype) + context = self.attention(context.mean(axis=1, keep_dims=True), context, context_mask) + return x + context.squeeze(1) + + +class Kandinsky3AttentionBlock(nn.Cell): + def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4): + super().__init__() + self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim) + self.attention = Attention( + num_channels, + context_dim or num_channels, + dim_head=head_dim, + out_dim=num_channels, + out_bias=False, + ) + + hidden_channels = expansion_ratio * num_channels + self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim) + self.feed_forward = nn.SequentialCell( + nn.Conv2d(num_channels, hidden_channels, kernel_size=1, has_bias=False), + SiLU(), + nn.Conv2d(hidden_channels, num_channels, kernel_size=1, has_bias=False), + ) + + def construct(self, x, time_embed, context=None, context_mask=None, image_mask=None): + height, width = x.shape[-2:] + out = self.in_norm(x, time_embed) + out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1) + context = context if context is not None else out + if context_mask is not None: + context_mask = 
context_mask.to(dtype=context.dtype) + + out = self.attention(out, context, context_mask) + out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width) + x = x + out + + out = self.out_norm(x, time_embed) + out = self.feed_forward(out) + x = x + out + return x diff --git a/mindone/diffusers/models/unets/unet_motion_model.py b/mindone/diffusers/models/unets/unet_motion_model.py new file mode 100644 index 0000000000..8ed0d055ad --- /dev/null +++ b/mindone/diffusers/models/unets/unet_motion_model.py @@ -0,0 +1,882 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Tuple, Union + +import mindspore as ms +from mindspore import nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import logging +from ..activations import SiLU +from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..normalization import GroupNorm +from ..transformers.transformer_temporal import TransformerTemporalModel +from .unet_2d_blocks import UNetMidBlock2DCrossAttn +from .unet_2d_condition import UNet2DConditionModel +from .unet_3d_blocks import ( + CrossAttnDownBlockMotion, + CrossAttnUpBlockMotion, + DownBlockMotion, + UNetMidBlockCrossAttnMotion, + UpBlockMotion, + get_down_block, + get_up_block, +) +from .unet_3d_condition import UNet3DConditionOutput + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class MotionModules(nn.Cell): + def __init__( + self, + in_channels: int, + layers_per_block: int = 2, + num_attention_heads: int = 8, + attention_bias: bool = False, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + norm_num_groups: int = 32, + max_seq_length: int = 32, + ): + super().__init__() + self.motion_modules = [] + + for i in range(layers_per_block): + self.motion_modules.append( + TransformerTemporalModel( + in_channels=in_channels, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels // num_attention_heads, + positional_embeddings="sinusoidal", + num_positional_embeddings=max_seq_length, + ) + ) + self.motion_modules = nn.CellList(self.motion_modules) + + +class MotionAdapter(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + block_out_channels: Tuple[int, ...] 
= (320, 640, 1280, 1280), + motion_layers_per_block: int = 2, + motion_mid_block_layers_per_block: int = 1, + motion_num_attention_heads: int = 8, + motion_norm_num_groups: int = 32, + motion_max_seq_length: int = 32, + use_motion_mid_block: bool = True, + conv_in_channels: Optional[int] = None, + ): + """Container to store AnimateDiff Motion Modules + + Args: + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each UNet block. + motion_layers_per_block (`int`, *optional*, defaults to 2): + The number of motion layers per UNet block. + motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1): + The number of motion layers in the middle UNet block. + motion_num_attention_heads (`int`, *optional*, defaults to 8): + The number of heads to use in each attention layer of the motion module. + motion_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use in each group normalization layer of the motion module. + motion_max_seq_length (`int`, *optional*, defaults to 32): + The maximum sequence length to use in the motion module. + use_motion_mid_block (`bool`, *optional*, defaults to True): + Whether to use a motion module in the middle of the UNet. + """ + + super().__init__() + down_blocks = [] + up_blocks = [] + + if conv_in_channels: + # input + self.conv_in = nn.Conv2d( + conv_in_channels, block_out_channels[0], kernel_size=3, pad_mode="pad", padding=1, has_bias=True + ) + else: + self.conv_in = None + + for i, channel in enumerate(block_out_channels): + output_channel = block_out_channels[i] + down_blocks.append( + MotionModules( + in_channels=output_channel, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + layers_per_block=motion_layers_per_block, + ) + ) + + if use_motion_mid_block: + self.mid_block = MotionModules( + in_channels=block_out_channels[-1], + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=motion_num_attention_heads, + layers_per_block=motion_mid_block_layers_per_block, + max_seq_length=motion_max_seq_length, + ) + else: + self.mid_block = None + + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, channel in enumerate(reversed_block_out_channels): + output_channel = reversed_block_out_channels[i] + up_blocks.append( + MotionModules( + in_channels=output_channel, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + layers_per_block=motion_layers_per_block + 1, + ) + ) + + self.down_blocks = nn.CellList(down_blocks) + self.up_blocks = nn.CellList(up_blocks) + + def construct(self, sample): + pass + + +class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a + sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). 
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "DownBlockMotion", + ), + up_block_types: Tuple[str, ...] = ( + "UpBlockMotion", + "CrossAttnUpBlockMotion", + "CrossAttnUpBlockMotion", + "CrossAttnUpBlockMotion", + ), + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + use_linear_projection: bool = False, + num_attention_heads: Union[int, Tuple[int, ...]] = 8, + motion_max_seq_length: int = 32, + motion_num_attention_heads: int = 8, + use_motion_mid_block: int = True, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + ): + super().__init__() + + self.sample_size = sample_size + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. " + f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. " + f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. " + f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
+ ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + pad_mode="pad", + padding=conv_in_padding, + has_bias=True, + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, time_embed_dim, act_fn=act_fn, cond_proj_dim=time_cond_proj_dim + ) + + if encoder_hid_dim_type is None: + self.encoder_hid_proj = None + + # class embedding + down_blocks = [] + up_blocks = [] + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + dual_cross_attention=False, + temporal_num_attention_heads=motion_num_attention_heads, + temporal_max_seq_length=motion_max_seq_length, + ) + down_blocks.append(down_block) + self.down_blocks = nn.CellList(down_blocks) + + # mid + if use_motion_mid_block: + self.mid_block = UNetMidBlockCrossAttnMotion( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + use_linear_projection=use_linear_projection, + temporal_num_attention_heads=motion_num_attention_heads, + temporal_max_seq_length=motion_max_seq_length, + ) + + else: + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + use_linear_projection=use_linear_projection, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + layers_per_resnet_in_up_blocks = [] + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + 
in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=False, + resolution_idx=i, + use_linear_projection=use_linear_projection, + temporal_num_attention_heads=motion_num_attention_heads, + temporal_max_seq_length=motion_max_seq_length, + ) + up_blocks.append(up_block) + prev_output_channel = output_channel + layers_per_resnet_in_up_blocks.append(len(up_block.resnets)) + self.up_blocks = nn.CellList(up_blocks) + self.layers_per_resnet_in_up_blocks = layers_per_resnet_in_up_blocks + + # out + if norm_num_groups is not None: + self.conv_norm_out = GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], + out_channels, + kernel_size=conv_out_kernel, + pad_mode="pad", + padding=conv_out_padding, + has_bias=True, + ) + + @classmethod + def from_unet2d( + cls, + unet: UNet2DConditionModel, + motion_adapter: Optional[MotionAdapter] = None, + load_weights: bool = True, + ): + has_motion_adapter = motion_adapter is not None + + # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459 + config = unet.config + config["_class_name"] = cls.__name__ + + down_blocks = [] + for down_blocks_type in config["down_block_types"]: + if "CrossAttn" in down_blocks_type: + down_blocks.append("CrossAttnDownBlockMotion") + else: + down_blocks.append("DownBlockMotion") + config["down_block_types"] = down_blocks + + up_blocks = [] + for down_blocks_type in config["up_block_types"]: + if "CrossAttn" in down_blocks_type: + up_blocks.append("CrossAttnUpBlockMotion") + else: + up_blocks.append("UpBlockMotion") + + config["up_block_types"] = up_blocks + + if has_motion_adapter: + config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"] + config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"] + config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"] + + # For PIA UNets we need to set the number input channels to 9 + if motion_adapter.config["conv_in_channels"]: + config["in_channels"] = motion_adapter.config["conv_in_channels"] + + # Need this for backwards compatibility with UNet2DConditionModel checkpoints + if not config.get("num_attention_heads"): + config["num_attention_heads"] = config["attention_head_dim"] + + model = cls.from_config(config) + + # Move dtype conversion code here to avoid dtype mismatch issues when loading weights + # ensure that the Motion UNet is the same dtype as the UNet2DConditionModel + model.to(unet.dtype) + + if not load_weights: + return model + + # Logic for loading PIA UNets which allow the first 4 channels to be any UNet2DConditionModel conv_in weight + # while the last 5 channels must be PIA conv_in weights. 
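Editor's note: to make the channel bookkeeping in the merge performed just below concrete, here is a small shape sketch. The channel counts are illustrative only; MindSpore conv weights are laid out as `(out_channels, in_channels, k, k)`.

```python
from mindspore import ops

# Illustrative conv_in weights: (out_channels, in_channels, kernel, kernel).
base_weight = ops.randn(320, 4, 3, 3)  # base UNet2DConditionModel conv_in (4 latent channels)
pia_weight = ops.randn(320, 9, 3, 3)   # PIA adapter conv_in (4 latent + 5 extra conditioning channels)

# Keep the base model's filters for the first 4 input channels and take the adapter's
# filters for the remaining channels, concatenating along the input-channel axis.
merged = ops.cat([base_weight, pia_weight[:, 4:, :, :]], axis=1)
print(merged.shape)  # (320, 9, 3, 3)
```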
+ if has_motion_adapter and motion_adapter.config["conv_in_channels"]: + model.conv_in = motion_adapter.conv_in + updated_conv_in_weight = ops.cat([unet.conv_in.weight, motion_adapter.conv_in.weight[:, 4:, :, :]], axis=1) + ms.load_param_into_net(model.conv_in, {"weight": updated_conv_in_weight, "bias": unet.conv_in.bias}) + else: + ms.load_param_into_net(model.conv_in, unet.conv_in.parameters_dict()) + + ms.load_param_into_net(model.time_proj, unet.time_proj.parameters_dict()) + ms.load_param_into_net(model.time_embedding, unet.time_embedding.parameters_dict()) + + for i, down_block in enumerate(unet.down_blocks): + ms.load_param_into_net(model.down_blocks[i].resnets, down_block.resnets.parameters_dict()) + if hasattr(model.down_blocks[i], "attentions"): + ms.load_param_into_net(model.down_blocks[i].attentions, down_block.attentions.parameters_dict()) + if model.down_blocks[i].downsamplers: + ms.load_param_into_net(model.down_blocks[i].downsamplers, down_block.downsamplers.parameters_dict()) + + for i, up_block in enumerate(unet.up_blocks): + ms.load_param_into_net(model.up_blocks[i].resnets, up_block.resnets.parameters_dict()) + if hasattr(model.up_blocks[i], "attentions"): + ms.load_param_into_net(model.up_blocks[i].attentions, up_block.attentions.parameters_dict()) + if model.up_blocks[i].upsamplers: + ms.load_param_into_net(model.up_blocks[i].upsamplers, up_block.upsamplers.parameters_dict()) + + ms.load_param_into_net(model.mid_block.resnets, unet.mid_block.resnets.parameters_dict()) + ms.load_param_into_net(model.mid_block.attentions, unet.mid_block.attentions.parameters_dict()) + + if unet.conv_norm_out is not None: + ms.load_param_into_net(model.conv_norm_out, unet.conv_norm_out.parameters_dict()) + if unet.conv_act is not None: + ms.load_param_into_net(model.conv_act, unet.conv_act.parameters_dict()) + ms.load_param_into_net(model.conv_out, unet.conv_out.parameters_dict()) + + if has_motion_adapter: + model.load_motion_modules(motion_adapter) + + return model + + def freeze_unet2d_params(self) -> None: + """Freeze the weights of just the UNet2DConditionModel, and leave the motion modules + unfrozen for fine tuning. 
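Editor's note: a typical way to use this method is to build the motion UNet from a pretrained 2D UNet, freeze the 2D weights, and train only the motion modules. The sketch below assumes the usual `from_pretrained` loading path mirrored from diffusers, and the checkpoint ids are placeholders; neither is defined by this diff.

```python
from mindone.diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel

# Placeholder checkpoints (assumption): any SD-1.5-style UNet plus an AnimateDiff motion adapter.
unet2d = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")

# Build the motion UNet from the 2D weights, then train only the motion modules.
unet = UNetMotionModel.from_unet2d(unet2d, motion_adapter=adapter)
unet.freeze_unet2d_params()  # only the motion modules keep requires_grad=True

trainable = [p for p in unet.get_parameters() if p.requires_grad]
print(f"trainable parameters: {sum(p.size for p in trainable):,}")
```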
+ """ + # Freeze everything + for param in self.get_parameters(): + param.requires_grad = False + + # Unfreeze Motion Modules + for down_block in self.down_blocks: + motion_modules = down_block.motion_modules + for param in motion_modules.get_parameters(): + param.requires_grad = True + + for up_block in self.up_blocks: + motion_modules = up_block.motion_modules + for param in motion_modules.get_parameters(): + param.requires_grad = True + + if hasattr(self.mid_block, "motion_modules"): + motion_modules = self.mid_block.motion_modules + for param in motion_modules.get_parameters(): + param.requires_grad = True + + def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None: + for i, down_block in enumerate(motion_adapter.down_blocks): + ms.load_param_into_net(self.down_blocks[i].motion_modules, down_block.motion_modules.parameters_dict()) + for i, up_block in enumerate(motion_adapter.up_blocks): + ms.load_param_into_net(self.up_blocks[i].motion_modules, up_block.motion_modules.parameters_dict()) + + # to support older motion modules that don't have a mid_block + if hasattr(self.mid_block, "motion_modules"): + ms.load_param_into_net( + self.mid_block.motion_modules, motion_adapter.mid_block.motion_modules.parameters_dict() + ) + + def save_motion_modules( + self, + save_directory: str, + is_main_process: bool = True, + safe_serialization: bool = True, + variant: Optional[str] = None, + push_to_hub: bool = False, + **kwargs, + ) -> None: + state_dict = self.parameters_dict() + + # Extract all motion modules + motion_state_dict = {} + for k, v in state_dict.items(): + if "motion_modules" in k: + motion_state_dict[k] = v + + adapter = MotionAdapter( + block_out_channels=self.config["block_out_channels"], + motion_layers_per_block=self.config["layers_per_block"], + motion_norm_num_groups=self.config["norm_num_groups"], + motion_num_attention_heads=self.config["motion_num_attention_heads"], + motion_max_seq_length=self.config["motion_max_seq_length"], + use_motion_mid_block=self.config["use_motion_mid_block"], + ) + ms.load_param_into_net(adapter, motion_state_dict) + adapter.save_pretrained( + save_directory=save_directory, + is_main_process=is_main_process, + safe_serialization=safe_serialization, + variant=variant, + push_to_hub=push_to_hub, + **kwargs, + ) + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: # type: ignore + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str, AttentionProcessor]): # type: ignore + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.name_cells().items(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.name_cells().items(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): # type: ignore + r""" + Sets the attention processor to use to compute attention. 
+ Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: nn.Cell, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.name_cells().items(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.name_cells().items(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.name_cells().items(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.name_cells().items(): + fn_recursive_feed_forward(module, chunk_size, dim) + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking + def disable_forward_chunking(self): + def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.name_cells().items(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.name_cells().items(): + fn_recursive_feed_forward(module, None, 0) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. 
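The attention-processor helpers copied in here mirror the upstream diffusers API: `attn_processors` walks the cell tree and returns a dict keyed by module path, and `set_attn_processor` accepts either a single processor or a dict with exactly those keys. A small sketch, assuming the `AttnProcessor` import path matches this port and continuing with `motion_unet` from above:

```python
from mindone.diffusers.models.attention_processor import AttnProcessor

procs = motion_unet.attn_processors   # {"down_blocks.0....processor": <processor>, ...}
print(len(procs))

# Either reset every layer to the default implementation...
motion_unet.set_default_attn_processor()

# ...or override each layer explicitly with a {path: processor} dict.
motion_unet.set_attn_processor({name: AttnProcessor() for name in procs})
```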
+ """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)): + module.gradient_checkpointing = value + + def construct( + self, + sample: ms.Tensor, + timestep: Union[ms.Tensor, float, int], + encoder_hidden_states: ms.Tensor, + timestep_cond: Optional[ms.Tensor] = None, + attention_mask: Optional[ms.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, ms.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[ms.Tensor]] = None, + mid_block_additional_residual: Optional[ms.Tensor] = None, + return_dict: bool = False, + ) -> Union[UNet3DConditionOutput, Tuple[ms.Tensor]]: + r""" + The [`UNetMotionModel`] forward method. + + Args: + sample (`ms.Tensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`. + timestep (`ms.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`ms.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + timestep_cond: (`ms.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`ms.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + down_block_additional_residuals: (`tuple` of `ms.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`ms.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. 
+ default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if sample.shape[-2] % default_overall_up_factor != 0 or sample.shape[-1] % default_overall_up_factor != 0: + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not ops.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + if isinstance(timestep, float): + dtype = ms.float64 + else: + dtype = ms.int64 + timesteps = ms.Tensor([timesteps], dtype=dtype) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None] + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + num_frames = sample.shape[2] + timesteps = timesteps.broadcast_to((sample.shape[0],)) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) + + if self.encoder_hid_proj is not None and self.config["encoder_hid_dim_type"] == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to " + f"'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds] + encoder_hidden_states = (encoder_hidden_states, image_embeds) + + # 2. pre-process + sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + sample = self.conv_in(sample) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames) + + down_block_res_samples += res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. 
mid + if self.mid_block is not None: + # To support older versions of motion modules that don't have a mid_block + if self.mid_block.has_motion_modules: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-self.layers_per_resnet_in_up_blocks[i] :] + down_block_res_samples = down_block_res_samples[: -self.layers_per_resnet_in_up_blocks[i]] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, channel, framerate, width, height) + sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]).permute(0, 2, 1, 3, 4) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) diff --git a/mindone/diffusers/models/unets/unet_spatio_temporal_condition.py b/mindone/diffusers/models/unets/unet_spatio_temporal_condition.py new file mode 100644 index 0000000000..6b7c42073a --- /dev/null +++ b/mindone/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -0,0 +1,496 @@ +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import mindspore as ms +from mindspore import nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import UNet2DConditionLoadersMixin +from ...utils import BaseOutput, logging +from ..activations import SiLU +from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from ..embeddings import TimestepEmbedding, Timesteps +from ..modeling_utils import ModelMixin +from ..normalization import GroupNorm +from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNetSpatioTemporalConditionOutput(BaseOutput): + """ + The output of [`UNetSpatioTemporalConditionModel`]. + + Args: + sample (`ms.Tensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 
+ """ + + sample: ms.Tensor = None + + +class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to + `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to + `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + addition_time_embed_dim: (`int`, defaults to 256): + Dimension to to encode the additional time ids. + projection_class_embeddings_input_dim (`int`, defaults to 768): + The dimension of the projection of encoded `added_time_ids`. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], + [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. + num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): + The number of attention heads. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 8, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + ), + up_block_types: Tuple[str] = ( + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + ), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + addition_time_embed_dim: int = 256, + projection_class_embeddings_input_dim: int = 768, + layers_per_block: Union[int, Tuple[int]] = 2, + cross_attention_dim: Union[int, Tuple[int]] = 1024, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20), + num_frames: int = 25, + ): + super().__init__() + + self.sample_size = sample_size + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. " + f"`down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. " + f"`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. " + f"`num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. " + f"`cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. " + f"`layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
+ ) + + # input + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + pad_mode="pad", + padding=1, + has_bias=True, + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + blocks_time_embed_dim = time_embed_dim + + # down + self.down_blocks = [] + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=1e-5, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + resnet_act_fn="silu", + ) + self.down_blocks.append(down_block) + self.down_blocks = nn.CellList(self.down_blocks) + + # mid + self.mid_block = UNetMidBlockSpatioTemporal( + block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block[-1], + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + self.up_blocks = [] + layers_per_resnet_in_up_blocks = [] + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=1e-5, + resolution_idx=i, + 
cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + resnet_act_fn="silu", + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + layers_per_resnet_in_up_blocks.append(len(up_block.resnets)) + self.up_blocks = nn.CellList(self.up_blocks) + self.layers_per_resnet_in_up_blocks = layers_per_resnet_in_up_blocks + + # out + self.conv_norm_out = GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5) + self.conv_act = SiLU() + + self.conv_out = nn.Conv2d( + block_out_channels[0], + out_channels, + kernel_size=3, + pad_mode="pad", + padding=1, + has_bias=True, + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: # type: ignore + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str, AttentionProcessor]): # type: ignore + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.name_cells().items(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.name_cells().items(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): # type: ignore + r""" + Sets the attention processor to use to compute attention. + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: nn.Cell, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.name_cells().items(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.name_cells().items(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. 
+ """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). + """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: nn.Cell, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.name_cells().items(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.name_cells().items(): + fn_recursive_feed_forward(module, chunk_size, dim) + + def construct( + self, + sample: ms.Tensor, + timestep: Union[ms.Tensor, float, int], + encoder_hidden_states: ms.Tensor, + added_time_ids: ms.Tensor, + return_dict: bool = False, + ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]: + r""" + The [`UNetSpatioTemporalConditionModel`] forward method. + + Args: + sample (`ms.Tensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`. + timestep (`ms.Tensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`ms.Tensor`): + The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`. + added_time_ids: (`ms.Tensor`): + The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal + embeddings and added to the time embeddings. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain + tuple. + Returns: + [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # 1. time + timesteps = timestep + if not ops.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + if isinstance(timestep, float): + dtype = ms.float64 + else: + dtype = ms.int64 + timesteps = ms.Tensor([timesteps], dtype=dtype) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None] + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + batch_size, num_frames = sample.shape[:2] + timesteps = timesteps.broadcast_to((batch_size,)) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb) + + time_embeds = self.add_time_proj(added_time_ids.flatten()) + time_embeds = time_embeds.reshape((batch_size, -1)) + time_embeds = time_embeds.to(emb.dtype) + aug_emb = self.add_embedding(time_embeds) + emb = emb + aug_emb + + # Flatten the batch and frames dimensions + # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width] + sample = sample.flatten(start_dim=0, end_dim=1) + # Repeat the embeddings num_video_frames times + # emb: [batch, channels] -> [batch * frames, channels] + emb = emb.repeat_interleave(num_frames, dim=0) + # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels] + encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0) + + # 2. pre-process + sample = self.conv_in(sample) + + image_only_indicator = ops.zeros((batch_size, num_frames), dtype=sample.dtype) + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + image_only_indicator=image_only_indicator, + ) + + down_block_res_samples += res_samples + + # 4. mid + sample = self.mid_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-self.layers_per_resnet_in_up_blocks[i] :] + down_block_res_samples = down_block_res_samples[: -self.layers_per_resnet_in_up_blocks[i]] + + if upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + image_only_indicator=image_only_indicator, + ) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # 7. Reshape back to original shape + sample = sample.reshape(batch_size, num_frames, *sample.shape[1:]) + + if not return_dict: + return (sample,) + + return UNetSpatioTemporalConditionOutput(sample=sample) diff --git a/mindone/diffusers/models/unets/uvit_2d.py b/mindone/diffusers/models/unets/uvit_2d.py new file mode 100644 index 0000000000..156876a8fe --- /dev/null +++ b/mindone/diffusers/models/unets/uvit_2d.py @@ -0,0 +1,449 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Union + +from mindspore import nn, ops + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ..attention import BasicTransformerBlock, SkipFFTransformerBlock +from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor +from ..embeddings import TimestepEmbedding, get_timestep_embedding +from ..modeling_utils import ModelMixin +from ..normalization import GlobalResponseNorm, RMSNorm +from ..resnet import Downsample2D, Upsample2D + + +class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + # global config + hidden_size: int = 1024, + use_bias: bool = False, + hidden_dropout: float = 0.0, + # conditioning dimensions + cond_embed_dim: int = 768, + micro_cond_encode_dim: int = 256, + micro_cond_embed_dim: int = 1280, + encoder_hidden_size: int = 768, + # num tokens + vocab_size: int = 8256, # codebook_size + 1 (for the mask token) rounded + codebook_size: int = 8192, + # `UVit2DConvEmbed` + in_channels: int = 768, + block_out_channels: int = 768, + num_res_blocks: int = 3, + downsample: bool = False, + upsample: bool = False, + block_num_heads: int = 12, + # `TransformerLayer` + num_hidden_layers: int = 22, + num_attention_heads: int = 16, + # `Attention` + attention_dropout: float = 0.0, + # `FeedForward` + intermediate_size: int = 2816, + # `Norm` + layer_norm_eps: float = 1e-6, + ln_elementwise_affine: bool = True, + sample_size: int = 64, + ): + super().__init__() + + self.encoder_proj = nn.Dense(encoder_hidden_size, hidden_size, has_bias=use_bias) + self.encoder_proj_layer_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine) + + self.embed = UVit2DConvEmbed( + in_channels, block_out_channels, vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias + ) + + self.cond_embed = TimestepEmbedding( + micro_cond_embed_dim + cond_embed_dim, hidden_size, sample_proj_bias=use_bias + ) + + self.down_block = UVitBlock( + block_out_channels, + num_res_blocks, + hidden_size, + hidden_dropout, + ln_elementwise_affine, + layer_norm_eps, + use_bias, + block_num_heads, + attention_dropout, + downsample, + False, + ) + + self.project_to_hidden_norm = RMSNorm(block_out_channels, layer_norm_eps, ln_elementwise_affine) + self.project_to_hidden = nn.Dense(block_out_channels, hidden_size, has_bias=use_bias) + + self.transformer_layers = nn.CellList( + [ + BasicTransformerBlock( + dim=hidden_size, + num_attention_heads=num_attention_heads, + attention_head_dim=hidden_size // num_attention_heads, + dropout=hidden_dropout, + cross_attention_dim=hidden_size, + attention_bias=use_bias, + norm_type="ada_norm_continuous", + ada_norm_continous_conditioning_embedding_dim=hidden_size, + norm_elementwise_affine=ln_elementwise_affine, + norm_eps=layer_norm_eps, + ada_norm_bias=use_bias, + ff_inner_dim=intermediate_size, + ff_bias=use_bias, + 
attention_out_bias=use_bias, + ) + for _ in range(num_hidden_layers) + ] + ) + + self.project_from_hidden_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine) + self.project_from_hidden = nn.Dense(hidden_size, block_out_channels, has_bias=use_bias) + + self.up_block = UVitBlock( + block_out_channels, + num_res_blocks, + hidden_size, + hidden_dropout, + ln_elementwise_affine, + layer_norm_eps, + use_bias, + block_num_heads, + attention_dropout, + downsample=False, + upsample=upsample, + ) + + self.mlm_layer = ConvMlmLayer( + block_out_channels, in_channels, use_bias, ln_elementwise_affine, layer_norm_eps, codebook_size + ) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + pass + + def construct(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None): + encoder_hidden_states = self.encoder_proj(encoder_hidden_states) + encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states) + + micro_cond_embeds = get_timestep_embedding( + micro_conds.flatten(), self.config["micro_cond_encode_dim"], flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1)).to(pooled_text_emb.dtype) + + pooled_text_emb = ops.cat([pooled_text_emb, micro_cond_embeds], axis=1) + pooled_text_emb = pooled_text_emb.to(dtype=encoder_hidden_states.dtype) + pooled_text_emb = self.cond_embed(pooled_text_emb).to(encoder_hidden_states.dtype) + + hidden_states = self.embed(input_ids) + + hidden_states = self.down_block( + hidden_states, + pooled_text_emb=pooled_text_emb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels) + + hidden_states = self.project_to_hidden_norm(hidden_states) + hidden_states = self.project_to_hidden(hidden_states) + + for layer in self.transformer_layers: + hidden_states = layer( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs={"pooled_text_emb": pooled_text_emb}, + ) + + hidden_states = self.project_from_hidden_norm(hidden_states) + hidden_states = self.project_from_hidden(hidden_states) + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + + hidden_states = self.up_block( + hidden_states, + pooled_text_emb=pooled_text_emb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + logits = self.mlm_layer(hidden_states) + + return logits + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: # type: ignore + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: nn.Cell, processors: Dict[str, AttentionProcessor]): # type: ignore + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.name_cells().items(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.name_cells().items(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): # type: ignore + r""" + Sets the attention processor to use to compute attention. + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: nn.Cell, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.name_cells().items(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.name_cells().items(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. 
+ """ + if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + +class UVit2DConvEmbed(nn.Cell): + def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias): + super().__init__() + self.embeddings = nn.Embedding(vocab_size, in_channels) + self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine) + self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, has_bias=bias) + + def construct(self, input_ids): + embeddings = self.embeddings(input_ids) + embeddings = self.layer_norm(embeddings) + embeddings = embeddings.permute(0, 3, 1, 2) + embeddings = self.conv(embeddings) + return embeddings + + +class UVitBlock(nn.Cell): + def __init__( + self, + channels, + num_res_blocks: int, + hidden_size, + hidden_dropout, + ln_elementwise_affine, + layer_norm_eps, + use_bias, + block_num_heads, + attention_dropout, + downsample: bool, + upsample: bool, + ): + super().__init__() + + if downsample: + self.downsample = Downsample2D( + channels, + use_conv=True, + padding=0, + name="Conv2d_0", + kernel_size=2, + norm_type="rms_norm", + eps=layer_norm_eps, + elementwise_affine=ln_elementwise_affine, + bias=use_bias, + ) + else: + self.downsample = None + + self.res_blocks = nn.CellList( + [ + ConvNextBlock( + channels, + layer_norm_eps, + ln_elementwise_affine, + use_bias, + hidden_dropout, + hidden_size, + ) + for i in range(num_res_blocks) + ] + ) + + self.attention_blocks = nn.CellList( + [ + SkipFFTransformerBlock( + channels, + block_num_heads, + channels // block_num_heads, + hidden_size, + use_bias, + attention_dropout, + channels, + attention_bias=use_bias, + attention_out_bias=use_bias, + ) + for _ in range(num_res_blocks) + ] + ) + + if upsample: + self.upsample = Upsample2D( + channels, + use_conv_transpose=True, + kernel_size=2, + padding=0, + name="conv", + norm_type="rms_norm", + eps=layer_norm_eps, + elementwise_affine=ln_elementwise_affine, + bias=use_bias, + interpolate=False, + ) + else: + self.upsample = None + + def construct(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs): + if self.downsample is not None: + x = self.downsample(x) + + for res_block, attention_block in zip(self.res_blocks, self.attention_blocks): + x = res_block(x, pooled_text_emb) + + batch_size, channels, height, width = x.shape + x = x.view(batch_size, channels, height * width).permute(0, 2, 1) + x = attention_block( + x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs + ) + x = x.permute(0, 2, 1).view(batch_size, channels, height, width) + + if self.upsample is not None: + x = self.upsample(x) + + return x + + +class ConvNextBlock(nn.Cell): + def __init__( + self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4 + ): + super().__init__() + self.depthwise = nn.Conv2d( + channels, + channels, + kernel_size=3, + pad_mode="pad", + padding=1, + group=channels, + has_bias=use_bias, + ) + self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine) + self.channelwise_linear_1 = nn.Dense(channels, int(channels * res_ffn_factor), has_bias=use_bias) + self.channelwise_act = nn.GELU() + self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor)) + 
self.channelwise_linear_2 = nn.Dense(int(channels * res_ffn_factor), channels, has_bias=use_bias) + self.channelwise_dropout = nn.Dropout(p=hidden_dropout) + self.cond_embeds_mapper = nn.Dense(hidden_size, channels * 2, has_bias=use_bias) + + def construct(self, x, cond_embeds): + x_res = x + + x = self.depthwise(x) + + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + + x = self.channelwise_linear_1(x) + x = self.channelwise_act(x) + x = self.channelwise_norm(x) + x = self.channelwise_linear_2(x) + x = self.channelwise_dropout(x) + + x = x.permute(0, 3, 1, 2) + + x = x + x_res + + scale, shift = self.cond_embeds_mapper(ops.silu(cond_embeds)).chunk(2, axis=1) + x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None] + + return x + + +class ConvMlmLayer(nn.Cell): + def __init__( + self, + block_out_channels: int, + in_channels: int, + use_bias: bool, + ln_elementwise_affine: bool, + layer_norm_eps: float, + codebook_size: int, + ): + super().__init__() + self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, has_bias=use_bias) + self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine) + self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, has_bias=use_bias) + + def construct(self, hidden_states): + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + logits = self.conv2(hidden_states) + return logits diff --git a/mindone/diffusers/models/vq_model.py b/mindone/diffusers/models/vq_model.py new file mode 100644 index 0000000000..acff1c73f6 --- /dev/null +++ b/mindone/diffusers/models/vq_model.py @@ -0,0 +1,174 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import mindspore as ms +from mindspore import nn + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer +from .modeling_utils import ModelMixin + + +@dataclass +class VQEncoderOutput(BaseOutput): + """ + Output of VQModel encoding method. + + Args: + latents (`ms.Tensor` of shape `(batch_size, num_channels, height, width)`): + The encoded output sample from the last layer of the model. + """ + + latents: ms.Tensor + + +class VQModel(ModelMixin, ConfigMixin): + r""" + A VQ-VAE model for decoding latent representations. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. 
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. + norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers. + vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. + scaling_factor (`float`, *optional*, defaults to `0.18215`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + norm_type (`str`, *optional*, defaults to `"group"`): + Type of normalization layer to use. Can be one of `"group"` or `"spatial"`. + """ + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 3, + sample_size: int = 32, + num_vq_embeddings: int = 256, + norm_num_groups: int = 32, + vq_embed_dim: Optional[int] = None, + scaling_factor: float = 0.18215, + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + lookup_from_codebook=False, + force_upcast=False, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=False, + mid_block_add_attention=mid_block_add_attention, + ) + + vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels + + self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1, has_bias=True) + self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False) + self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1, has_bias=True) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + norm_type=norm_type, + mid_block_add_attention=mid_block_add_attention, + ) + + def encode(self, x: ms.Tensor, return_dict: bool = False): + h = self.encoder(x) + h = self.quant_conv(h) + + if not return_dict: + return (h,) + + return VQEncoderOutput(latents=h) + + def decode(self, h: ms.Tensor, force_not_quantize: bool = False, return_dict: bool = False, shape=None): + # also go through quantization layer + if not force_not_quantize: + quant, _, _ = self.quantize(h) + elif self.config["lookup_from_codebook"]: + quant = self.quantize.get_codebook_entry(h, shape) + else: + quant = h + quant2 = self.post_quant_conv(quant) + dec = self.decoder(quant2, quant if self.config["norm_type"] == "spatial" else None) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def construct(self, sample: ms.Tensor, return_dict: bool = False) -> Union[DecoderOutput, Tuple[ms.Tensor, ...]]: + r""" + The [`VQModel`] forward method. + + Args: + sample (`ms.Tensor`): Input sample. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vq_model.VQEncoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple` + is returned. 
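A hedged round-trip sketch, reusing the same small configuration that the test cases further down register for `VQModel` (so the shapes are grounded in this patch; the call pattern itself is illustrative):

```python
import numpy as np
import mindspore as ms
from mindone.diffusers import VQModel

vq = VQModel(
    block_out_channels=[32, 64],
    in_channels=3,
    out_channels=3,
    down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
    up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
    latent_channels=3,
)

images = ms.Tensor(np.random.randn(4, 3, 32, 32), ms.float32)
(latents,) = vq.encode(images, return_dict=False)   # pre-quantization latents, (4, 3, 16, 16)
(recon,) = vq.decode(latents, return_dict=False)    # quantize, then decode back to (4, 3, 32, 32)
```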
+ """ + + h = self.encode(sample)[0] + dec = self.decode(h)[0] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/tests/diffusers/models/test_layers.py b/tests/diffusers/models/modeling_test_utils.py similarity index 54% rename from tests/diffusers/models/test_layers.py rename to tests/diffusers/models/modeling_test_utils.py index a3ed2f1eb1..10f27f1690 100644 --- a/tests/diffusers/models/test_layers.py +++ b/tests/diffusers/models/modeling_test_utils.py @@ -1,34 +1,14 @@ import importlib -import inspect import logging import numpy as np -import pytest import torch from diffusers.utils import BaseOutput import mindspore as ms from mindspore import nn, ops -from .test_layers_cases import ALL_CASES - -logger = logging.getLogger("ModulesUnitTest") - - -THRESHOLD_FP16 = 1e-2 -THRESHOLD_FP32 = 1e-3 - - -PT_DTYPE_MAPPING = { - "fp16": torch.float16, - "fp32": torch.float32, -} - - -MS_DTYPE_MAPPING = { - "fp16": ms.float16, - "fp32": ms.float32, -} +logger = logging.getLogger("ModelingsUnitTest") TORCH_FP16_BLACKLIST = ( @@ -40,6 +20,7 @@ "FirUpsample2D", "FirDownsample2D", "KDownsample2D", + "AutoencoderTiny", ) @@ -63,13 +44,13 @@ def get_pt2ms_mappings(m): return mappings -# copied from mindone.diffusers.models.modeling_utils +# adapted from mindone.diffusers.models.modeling_utils def convert_state_dict(m, state_dict_pt): mappings = get_pt2ms_mappings(m) state_dict_ms = {} for name_pt, data_pt in state_dict_pt.items(): name_ms, data_mapping = mappings.get(name_pt, (name_pt, lambda x: x)) - data_ms = data_mapping(ms.Parameter(ms.Tensor.from_numpy(data_pt.numpy()))) + data_ms = ms.Parameter(data_mapping(ms.Tensor.from_numpy(data_pt.numpy()))) if name_ms is not None: state_dict_ms[name_ms] = data_ms return state_dict_ms @@ -130,7 +111,7 @@ def set_dtype(model, dtype): return model -def common_parse_args(pt_dtype, ms_dtype, *args, **kwargs): +def generalized_parse_args(pt_dtype, ms_dtype, *args, **kwargs): dtype_mappings = { "fp32": np.float32, "fp16": np.float16, @@ -196,159 +177,3 @@ def compute_diffs(pt_outputs: torch.Tensor, ms_outputs: ms.Tensor): diffs.append(d) return diffs - - -@pytest.mark.parametrize( - "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs", - ALL_CASES, -) -def test_named_modules_with_graph_fp32( - name, - pt_module, - ms_module, - init_args, - init_kwargs, - inputs_args, - inputs_kwargs, -): - dtype = "fp32" - ms.set_context(mode=ms.GRAPH_MODE, jit_syntax_level=ms.STRICT) - - ( - pt_model, - ms_model, - pt_dtype, - ms_dtype, - ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) - pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = common_parse_args( - pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs - ) - - with torch.no_grad(): - pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) - ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) - - diffs = compute_diffs(pt_outputs, ms_outputs) - - assert ( - np.array(diffs) < THRESHOLD_FP32 - ).all(), f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD_FP32}" - - -@pytest.mark.parametrize( - "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs", - ALL_CASES, -) -def test_named_modules_with_graph_fp16( - name, - pt_module, - ms_module, - init_args, - init_kwargs, - inputs_args, - inputs_kwargs, -): - dtype = "fp16" - ms.set_context(mode=ms.GRAPH_MODE, jit_syntax_level=ms.STRICT) - - ( - pt_model, - ms_model, - pt_dtype, - ms_dtype, - ) = get_modules(pt_module, ms_module, dtype, 
*init_args, **init_kwargs) - pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = common_parse_args( - pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs - ) - - if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: - pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) - ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) - - with torch.no_grad(): - pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) - ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) - - diffs = compute_diffs(pt_outputs, ms_outputs) - - assert ( - np.array(diffs) < THRESHOLD_FP16 - ).all(), f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD_FP16}" - - -@pytest.mark.parametrize( - "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs", - ALL_CASES, -) -def test_named_modules_with_pynative_fp32( - name, - pt_module, - ms_module, - init_args, - init_kwargs, - inputs_args, - inputs_kwargs, -): - dtype = "fp32" - ms.set_context(mode=ms.PYNATIVE_MODE, jit_syntax_level=ms.STRICT) - - ( - pt_model, - ms_model, - pt_dtype, - ms_dtype, - ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) - pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = common_parse_args( - pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs - ) - - with torch.no_grad(): - pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) - ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) - - diffs = compute_diffs(pt_outputs, ms_outputs) - - assert ( - np.array(diffs) < THRESHOLD_FP32 - ).all(), f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD_FP32}" - - -@pytest.mark.parametrize( - "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs", - ALL_CASES, -) -def test_named_modules_with_pynative_fp16( - name, - pt_module, - ms_module, - init_args, - init_kwargs, - inputs_args, - inputs_kwargs, -): - dtype = "fp16" - ms.set_context(mode=ms.PYNATIVE_MODE, jit_syntax_level=ms.STRICT) - - ( - pt_model, - ms_model, - pt_dtype, - ms_dtype, - ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) - pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = common_parse_args( - pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs - ) - - if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: - pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) - ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) - - with torch.no_grad(): - pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) - ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) - - diffs = compute_diffs(pt_outputs, ms_outputs) - - assert ( - np.array(diffs) < THRESHOLD_FP16 - ).all(), f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD_FP16}" diff --git a/tests/diffusers/models/test_layers_cases.py b/tests/diffusers/models/modules_test_cases.py similarity index 55% rename from tests/diffusers/models/test_layers_cases.py rename to tests/diffusers/models/modules_test_cases.py index 091543c5ef..f2b175c037 100644 --- a/tests/diffusers/models/test_layers_cases.py +++ b/tests/diffusers/models/modules_test_cases.py @@ -1,3 +1,20 @@ +# This file defined a list containing all generalized test cases. Each test case is represented as a list or tuple following the structure: +# [name, pt_module, ms_module, init_args, init_kwargs, inputs_args, inputs_kwargs]. 
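+#
+# For illustration only, a (hypothetical) entry following this structure could look like:
+#     ["GroupNorm", "torch.nn.GroupNorm", "mindspore.nn.GroupNorm", (4, 32), {}, (np.random.randn(2, 32, 8, 8).astype(np.float32),), {}],
+# None of the real cases below use plain framework layers; this sketch only shows how the seven fields line up.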
+# +# Parameters: +# name: +# A string identifier for the test case, primarily for diagnostic purposes and not utilized during execution. +# ms_module: +# The module from 'mindone.diffusers' under test, e.g., 'mindone.diffusers.models.model0'. +# pt_module: +# The counterpart module from the original 'diffusers' library for accuracy benchmarking, matching ms_module in functionality. +# init_args, init_kwargs: +# Arguments for initializing the modules, positional and keyword respectively. +# inputs_args, inputs_kwargs: +# Arguments for model inputs, positional and keyword respectively. These are initially defined with numpy for compatibility and are converted +# to PyTorch or MindSpore formats via the `.modeling_test_utils.generalized_parse_args` utility. If this utility's conversions do not suffice, +# a specific unit test should be developed rather than relying on generic test cases. + import numpy as np # layers @@ -426,6 +443,220 @@ ] +LAYERS_CASES = ( + NORMALIZATION_CASES + EMBEDDINGS_CASES + UPSAMPLE2D_CASES + DOWNSAMPLE2D_CASES + RESNET_CASES + T2I_ADAPTER_CASES +) + + +# autoencoders +# VQModel: volatile in fp16(fyi: 2%-20% diff when torch.fp16 vs torch.fp32) +VQ_CASES = [ + [ + "VQModel", # volatile with random init: 2%-20% diff when torch.float16 v.s. torch.float32 + "diffusers.models.vq_model.VQModel", + "mindone.diffusers.models.vq_model.VQModel", + (), + { + "block_out_channels": [32, 64], + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], + "latent_channels": 3, + }, + (), + {"sample": np.random.randn(4, 3, 32, 32).astype(np.float32), "return_dict": False}, + ], +] + + +VAE_CASES = [ + [ + "AutoencoderKL", + "diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL", + "mindone.diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL", + (), + { + "block_out_channels": [32, 64], + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D"] * len([32, 64]), + "up_block_types": ["UpDecoderBlock2D"] * len([32, 64]), + "latent_channels": 4, + "norm_num_groups": 32, + }, + (), + { + "sample": np.random.randn(4, 3, 32, 32).astype(np.float32), + "return_dict": False, + }, + ], + [ + "AsymmetricAutoencoderKL", + "diffusers.models.autoencoders.autoencoder_asym_kl.AsymmetricAutoencoderKL", + "mindone.diffusers.models.autoencoders.autoencoder_asym_kl.AsymmetricAutoencoderKL", + (), + { + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D"] * len([32, 64]), + "down_block_out_channels": [32, 64], + "layers_per_down_block": 1, + "up_block_types": ["UpDecoderBlock2D"] * len([32, 64]), + "up_block_out_channels": [32, 64], + "layers_per_up_block": 1, + "act_fn": "silu", + "latent_channels": 4, + "norm_num_groups": 32, + "sample_size": 32, + "scaling_factor": 0.18215, + }, + (), + { + "sample": np.random.randn(4, 3, 32, 32).astype(np.float32), + "mask": np.random.randn(4, 1, 32, 32).astype(np.float32), + "return_dict": False, + }, + ], + [ + "AutoencoderTiny", + "diffusers.models.autoencoders.autoencoder_tiny.AutoencoderTiny", + "mindone.diffusers.models.autoencoders.autoencoder_tiny.AutoencoderTiny", + (), + { + "in_channels": 3, + "out_channels": 3, + "encoder_block_out_channels": [32, 32], + "decoder_block_out_channels": [32, 32], + "num_encoder_blocks": [b // min([32, 32]) for b in [32, 32]], + "num_decoder_blocks": [b // min([32, 32]) for b in reversed([32, 32])], + }, + (), + { + "sample": np.random.randn(4, 3, 
32, 32).astype(np.float32), + "return_dict": False, + }, + ], + [ + "AutoencoderKLTemporalDecoder", + "diffusers.models.autoencoders.autoencoder_kl_temporal_decoder.AutoencoderKLTemporalDecoder", + "mindone.diffusers.models.autoencoders.autoencoder_kl_temporal_decoder.AutoencoderKLTemporalDecoder", + (), + { + "block_out_channels": [32, 64], + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "latent_channels": 4, + "layers_per_block": 2, + }, + (), + { + "sample": np.random.randn(3, 3, 32, 32).astype(np.float32), + "num_frames": 3, + "return_dict": False, + }, + ], +] + + +AE_CASES = VQ_CASES + VAE_CASES + + +# transformers +TRANSFORMER2D_CASES = [ + [ + "SpatialTransformer2DModel_default", + "diffusers.models.transformers.transformer_2d.Transformer2DModel", + "mindone.diffusers.models.transformers.transformer_2d.Transformer2DModel", + (), + dict( + in_channels=32, + num_attention_heads=1, + attention_head_dim=32, + dropout=0.0, + cross_attention_dim=None, + ), + (np.random.randn(1, 32, 64, 64).astype(np.float32),), + dict(return_dict=False), + ], + [ + "SpatialTransformer2DModel_cross_attention_dim", + "diffusers.models.transformers.transformer_2d.Transformer2DModel", + "mindone.diffusers.models.transformers.transformer_2d.Transformer2DModel", + (), + dict( + in_channels=64, + num_attention_heads=2, + attention_head_dim=32, + dropout=0.0, + cross_attention_dim=64, + ), + (np.random.randn(1, 64, 64, 64).astype(np.float32), np.random.randn(1, 4, 64).astype(np.float32)), + dict(return_dict=False), + ], + [ + "SpatialTransformer2DModel_dropout", + "diffusers.models.transformers.transformer_2d.Transformer2DModel", + "mindone.diffusers.models.transformers.transformer_2d.Transformer2DModel", + (), + dict( + in_channels=32, + num_attention_heads=2, + attention_head_dim=16, + dropout=0.3, + cross_attention_dim=None, + ), + (np.random.randn(1, 32, 64, 64).astype(np.float32),), + dict(return_dict=False), + ], + [ + "SpatialTransformer2DModel_discrete", + "diffusers.models.transformers.transformer_2d.Transformer2DModel", + "mindone.diffusers.models.transformers.transformer_2d.Transformer2DModel", + (), + dict( + num_attention_heads=1, + attention_head_dim=32, + num_vector_embeds=5, + sample_size=16, + ), + (np.random.randint(0, 5, (1, 32)).astype(np.int64),), + dict(return_dict=False), + ], +] + + +PRIOR_TRANSFORMER_CASES = [ + [ + "PriorTransformer", + "diffusers.models.transformers.prior_transformer.PriorTransformer", + "mindone.diffusers.models.transformers.prior_transformer.PriorTransformer", + (), + { + "num_attention_heads": 2, + "attention_head_dim": 4, + "num_layers": 2, + "embedding_dim": 8, + "num_embeddings": 7, + "additional_embeddings": 4, + }, + (), + { + "hidden_states": np.random.randn(4, 8).astype(np.float32), + "timestep": 2, + "proj_embedding": np.random.randn(4, 8).astype(np.float32), + "encoder_hidden_states": np.random.randn(4, 7, 8).astype(np.float32), + "return_dict": False, + }, + ], +] + + +TRANSFORMERS_CASES = TRANSFORMER2D_CASES + PRIOR_TRANSFORMER_CASES + + +# unet UNET1D_CASES = [ [ "UNet1DModel", @@ -484,6 +715,210 @@ ] +UVIT2D_CASES = [ + [ + "UVit2DModel", + "diffusers.models.unets.uvit_2d.UVit2DModel", + "mindone.diffusers.models.unets.uvit_2d.UVit2DModel", + (), + dict( + hidden_size=32, + use_bias=False, + hidden_dropout=0.0, + cond_embed_dim=32, + micro_cond_encode_dim=2, + micro_cond_embed_dim=10, + encoder_hidden_size=32, + vocab_size=32, + codebook_size=32, + in_channels=32, + block_out_channels=32, + 
num_res_blocks=1, + downsample=True, + upsample=True, + block_num_heads=1, + num_hidden_layers=1, + num_attention_heads=1, + attention_dropout=0.0, + intermediate_size=32, + layer_norm_eps=1e-06, + ln_elementwise_affine=True, + ), + (), + { + "input_ids": np.random.randint(0, 32, (2, 4, 4)).astype(np.int64), + "encoder_hidden_states": np.random.randn(2, 77, 32).astype(np.float32), + "pooled_text_emb": np.random.randn(2, 32).astype(np.float32), + "micro_conds": np.random.randn(2, 5).astype(np.float32), + }, + ], +] + + +KANDINSKY3_CASES = [ + [ + "Kandinsky3UNet", + "diffusers.models.unets.unet_kandinsky3.Kandinsky3UNet", + "mindone.diffusers.models.unets.unet_kandinsky3.Kandinsky3UNet", + (), + dict( + in_channels=4, + time_embedding_dim=4, + groups=2, + attention_head_dim=4, + layers_per_block=3, + block_out_channels=(32, 64), + cross_attention_dim=4, + encoder_hid_dim=32, + ), + (), + { + "sample": np.random.randn(2, 4, 8, 8).astype(np.float32), + "timestep": np.array([10]).astype(np.int64), + "encoder_hidden_states": np.random.randn(2, 36, 32).astype(np.float32), + "encoder_attention_mask": np.ones((2, 36)).astype(np.float32), + "return_dict": False, + }, + ], +] + + +UNET3D_CONDITION_MODEL_CASES = [ + [ + "UNet3DConditionModel", + "diffusers.models.unets.unet_3d_condition.UNet3DConditionModel", + "mindone.diffusers.models.unets.unet_3d_condition.UNet3DConditionModel", + (), + { + "block_out_channels": (32, 64), + "down_block_types": ( + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D"), + "cross_attention_dim": 32, + "attention_head_dim": 8, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 1, + "sample_size": 32, + "norm_num_groups": 32, + }, + (), + { + "sample": np.random.randn(4, 4, 4, 32, 32).astype(np.float32), + "timestep": np.array([10]).astype(np.int64), + "encoder_hidden_states": np.random.randn(4, 4, 32).astype(np.float32), + "return_dict": False, + }, + ], +] + + +UNET_SPATIO_TEMPORAL_CONDITION_MODEL_CASES = [ + [ + "UNetSpatioTemporalConditionModel", + "diffusers.models.unets.unet_spatio_temporal_condition.UNetSpatioTemporalConditionModel", + "mindone.diffusers.models.unets.unet_spatio_temporal_condition.UNetSpatioTemporalConditionModel", + (), + { + "block_out_channels": (32, 64), + "down_block_types": ( + "CrossAttnDownBlockSpatioTemporal", + "DownBlockSpatioTemporal", + ), + "up_block_types": ( + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", + ), + "cross_attention_dim": 32, + "num_attention_heads": 8, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 2, + "sample_size": 32, + "projection_class_embeddings_input_dim": 32 * 3, + "addition_time_embed_dim": 32, + }, + (), + { + "sample": np.random.randn(2, 2, 4, 32, 32).astype(np.float32), + "timestep": np.array([10]).astype(np.int64), + "encoder_hidden_states": np.random.randn(2, 1, 32).astype(np.float32), + "added_time_ids": np.array([[6, 127, 0.02], [6, 127, 0.02]]).astype(np.float32), + "return_dict": False, + }, + ], +] + + +UNET_I2VGEN_XL_CASES = [ + [ + "I2VGenXLUNet", + "diffusers.models.unets.unet_i2vgen_xl.I2VGenXLUNet", + "mindone.diffusers.models.unets.unet_i2vgen_xl.I2VGenXLUNet", + (), + dict( + sample_size=None, + in_channels=4, + out_channels=4, + down_block_types=( + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + up_block_types=( + "UpBlock3D", + "CrossAttnUpBlock3D", + ), + block_out_channels=(32, 64), + layers_per_block=2, + norm_num_groups=4, + cross_attention_dim=32, + attention_head_dim=4, + 
num_attention_heads=None,
+        ),
+        (),
+        {
+            "sample": np.random.randn(2, 4, 2, 32, 32).astype(np.float32),
+            "timestep": np.array([10]).astype(np.int64),
+            "fps": np.array([2]).astype(np.int64),
+            "image_latents": np.random.randn(2, 4, 2, 32, 32).astype(np.float32),
+            "image_embeddings": np.random.randn(2, 32).astype(np.float32),
+            "encoder_hidden_states": np.random.randn(2, 1, 32).astype(np.float32),
+            "return_dict": False,
+        },
+    ],
+]
+
+
+UNET_MOTION_MODEL_TEST = [
+    [
+        "UNetMotionModel",
+        "diffusers.models.unets.unet_motion_model.UNetMotionModel",
+        "mindone.diffusers.models.unets.unet_motion_model.UNetMotionModel",
+        (),
+        {
+            "block_out_channels": (32, 64),
+            "down_block_types": ("CrossAttnDownBlockMotion", "DownBlockMotion"),
+            "up_block_types": ("UpBlockMotion", "CrossAttnUpBlockMotion"),
+            "cross_attention_dim": 32,
+            "num_attention_heads": 4,
+            "out_channels": 4,
+            "in_channels": 4,
+            "layers_per_block": 1,
+            "sample_size": 32,
+        },
+        (),
+        {
+            "sample": np.random.randn(4, 4, 8, 32, 32).astype(np.float32),
+            "timestep": np.array([10]).astype(np.int64),
+            "encoder_hidden_states": np.random.randn(4, 4, 32).astype(np.float32),
+            "return_dict": False,
+        },
+    ],
+]
+
+
 UNETSTABLECASCADE_CASES = [
     [
         "UNetStableCascadeModel_prior",
@@ -577,14 +1012,17 @@
 ]
 
 
-# CONTROL_NET_CASES: outputs format isn't same with others
-ALL_CASES = (
-    NORMALIZATION_CASES
-    + EMBEDDINGS_CASES
-    + UPSAMPLE2D_CASES
-    + DOWNSAMPLE2D_CASES
-    + RESNET_CASES
-    + T2I_ADAPTER_CASES
-    + UNET1D_CASES
+UNETS_CASES = (
+    UNET1D_CASES
+    + UVIT2D_CASES
+    + KANDINSKY3_CASES
+    + UNET3D_CONDITION_MODEL_CASES
+    + UNET_SPATIO_TEMPORAL_CONDITION_MODEL_CASES
+    + UNET_I2VGEN_XL_CASES
+    + UNET_MOTION_MODEL_TEST
     + UNETSTABLECASCADE_CASES
 )
+
+
+# all
+ALL_CASES = LAYERS_CASES + AE_CASES + TRANSFORMERS_CASES + UNETS_CASES
diff --git a/tests/diffusers/models/test_generic_modules.py b/tests/diffusers/models/test_generic_modules.py
new file mode 100644
index 0000000000..d7d7c1907f
--- /dev/null
+++ b/tests/diffusers/models/test_generic_modules.py
@@ -0,0 +1,191 @@
+# This module runs the generic test cases defined in `modules_test_cases.py`, structured as lists or tuples like
+# [name, pt_module, ms_module, init_args, init_kwargs, inputs_args, inputs_kwargs].
+#
+# Each defined case corresponds to a pair consisting of PyTorch and MindSpore modules, including their respective
+# initialization parameters and inputs for the forward pass. The testing framework adopted here is designed to generically
+# parse these parameters to assess and compare the precision of forward outcomes between the two frameworks.
+#
+# In cases where models have unique initialization procedures or require testing with specialized output formats,
+# it is necessary to develop distinct, dedicated test cases.
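+#
+# Assuming pytest and both frameworks are installed, something like the following should select only the
+# graph-mode fp32 variants of this generic suite (drop the -k filter to run every case in every mode/dtype):
+#     pytest tests/diffusers/models/test_generic_modules.py -k "graph_fp32"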
+ +import inspect + +import numpy as np +import pytest +import torch + +import mindspore as ms + +from .modeling_test_utils import compute_diffs, generalized_parse_args, get_modules +from .modules_test_cases import ALL_CASES + +THRESHOLD_FP16 = 1e-2 +THRESHOLD_FP32 = 5e-3 + + +PT_DTYPE_MAPPING = { + "fp16": torch.float16, + "fp32": torch.float32, +} + + +MS_DTYPE_MAPPING = { + "fp16": ms.float16, + "fp32": ms.float32, +} + + +@pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs", + ALL_CASES, +) +def test_named_modules_with_graph_fp32( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, +): + dtype = "fp32" + ms.set_context(mode=ms.GRAPH_MODE, jit_syntax_level=ms.STRICT) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + + diffs = compute_diffs(pt_outputs, ms_outputs) + + assert ( + np.array(diffs) < THRESHOLD_FP32 + ).all(), f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD_FP32}" + + +@pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs", + ALL_CASES, +) +def test_named_modules_with_graph_fp16( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, +): + dtype = "fp16" + ms.set_context(mode=ms.GRAPH_MODE, jit_syntax_level=ms.STRICT) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + + diffs = compute_diffs(pt_outputs, ms_outputs) + + assert ( + np.array(diffs) < THRESHOLD_FP16 + ).all(), f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD_FP16}" + + +@pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs", + ALL_CASES, +) +def test_named_modules_with_pynative_fp32( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, +): + dtype = "fp32" + ms.set_context(mode=ms.PYNATIVE_MODE, jit_syntax_level=ms.STRICT) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + + diffs = compute_diffs(pt_outputs, ms_outputs) + + assert ( + np.array(diffs) < THRESHOLD_FP32 + ).all(), f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD_FP32}" + + 
+@pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs", + ALL_CASES, +) +def test_named_modules_with_pynative_fp16( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, +): + dtype = "fp16" + ms.set_context(mode=ms.PYNATIVE_MODE, jit_syntax_level=ms.STRICT) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + + diffs = compute_diffs(pt_outputs, ms_outputs) + + assert ( + np.array(diffs) < THRESHOLD_FP16 + ).all(), f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD_FP16}" diff --git a/tests/diffusers/models/test_transformers.py b/tests/diffusers/models/test_transformers.py new file mode 100644 index 0000000000..5f3b178f24 --- /dev/null +++ b/tests/diffusers/models/test_transformers.py @@ -0,0 +1,71 @@ +import numpy as np +import pytest +import torch + +import mindspore as ms + +from .modeling_test_utils import compute_diffs, generalized_parse_args, get_modules + +THRESHOLD_FP16 = 1e-2 +THRESHOLD_FP32 = 5e-3 + + +@pytest.mark.parametrize( + "name,mode,dtype", + [ + ["T5FilmDecoder_graph_fp32", 0, "fp32"], + ["T5FilmDecoder_graph_fp16", 0, "fp16"], + ["T5FilmDecoder_pynative_fp32", 1, "fp32"], + ["T5FilmDecoder_pynative_fp16", 1, "fp16"], + ], +) +def test_t5_film_decoder(name, mode, dtype): + ms.set_context(mode=mode, jit_syntax_level=ms.STRICT) + + # init model + pt_module = "diffusers.models.transformers.t5_film_transformer.T5FilmDecoder" + ms_module = f"mindone.{pt_module}" + + init_args = () + init_kwargs = { + "input_dims": 32, + "d_model": 64, + "num_heads": 4, + "d_ff": 64, + "targets_length": 8, + } + + pt_model, ms_model, pt_dtype, ms_dtype = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + + # get inputs + inputs_args = ( + np.random.randn(2, 64, 64).astype(np.float32), + np.random.randn(2, 64).astype(np.int32), + np.random.randn(2, 8, 64).astype(np.float32), + np.random.randn(2, 8).astype(np.int32), + ) + inputs_kwargs = { + "decoder_input_tokens": np.random.randn(2, 8, 32).astype(np.float32), + "decoder_noise_time": np.array([0.99, 0.50]).astype(np.float32), + } + + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + pt_inputs_kwargs["encodings_and_masks"] = ( + (pt_inputs_args[0], pt_inputs_args[1]), + (pt_inputs_args[2], pt_inputs_args[3]), + ) + ms_inputs_kwargs["encodings_and_masks"] = ( + (ms_inputs_args[0], ms_inputs_args[1]), + (ms_inputs_args[2], ms_inputs_args[3]), + ) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs_kwargs) + ms_outputs = ms_model(**ms_inputs_kwargs) + + diffs = compute_diffs(pt_outputs, ms_outputs) + + eps = THRESHOLD_FP16 if dtype == "fp16" else THRESHOLD_FP32 + assert (np.array(diffs) < eps).all(), f"Outputs({np.array(diffs).tolist()}) has diff bigger than {eps}"