From 6271bc5fcff4fa8505660d592edf52bd92d817be Mon Sep 17 00:00:00 2001 From: Cui-yshoho <73014084+Cui-yshoho@users.noreply.github.com> Date: Wed, 19 Jun 2024 10:49:30 +0800 Subject: [PATCH] feat(diffusers/models/unets): add unet1d & StableCascade (#519) --- mindone/diffusers/README.md | 13 +- mindone/diffusers/__init__.py | 12 +- mindone/diffusers/models/__init__.py | 4 +- mindone/diffusers/models/modeling_utils.py | 6 +- mindone/diffusers/models/unets/__init__.py | 2 + mindone/diffusers/models/unets/unet_1d.py | 261 ++++++ .../diffusers/models/unets/unet_1d_blocks.py | 746 ++++++++++++++++++ .../models/unets/unet_stable_cascade.py | 619 +++++++++++++++ tests/diffusers/models/test_layers.py | 6 +- tests/diffusers/models/test_layers_cases.py | 160 +++- 10 files changed, 1821 insertions(+), 8 deletions(-) create mode 100644 mindone/diffusers/models/unets/unet_1d.py create mode 100644 mindone/diffusers/models/unets/unet_1d_blocks.py create mode 100644 mindone/diffusers/models/unets/unet_stable_cascade.py diff --git a/mindone/diffusers/README.md b/mindone/diffusers/README.md index 0115b45c98..dcb86ea3d9 100644 --- a/mindone/diffusers/README.md +++ b/mindone/diffusers/README.md @@ -97,9 +97,20 @@ Most base, utility and mixin class are available. - [ ] StableDiffusionPipeline ### Model + +#### AutoEncoders + - [x] AutoencoderKL -- [x] Transformer2DModel + +#### UNets + +- [x] UNet1DModel - [x] UNet2DConditionModel +- [x] StableCascadeUNet + +#### Transformers + +- [x] Transformer2DModel ### Scheduler - [x] DDIMScheduler/DDPMScheduler/...(30) diff --git a/mindone/diffusers/__init__.py b/mindone/diffusers/__init__.py index b6a93bf58b..d3024f5efd 100644 --- a/mindone/diffusers/__init__.py +++ b/mindone/diffusers/__init__.py @@ -17,8 +17,10 @@ "AutoencoderKL", "ModelMixin", "SD3Transformer2DModel", + "UNet1DModel", "UNet2DConditionModel", "UNet2DModel", + "StableCascadeUNet", ], "optimization": [ "get_constant_schedule", @@ -77,7 +79,15 @@ if TYPE_CHECKING: from .configuration_utils import ConfigMixin - from .models import AutoencoderKL, ModelMixin, SD3Transformer2DModel, UNet2DConditionModel, UNet2DModel + from .models import ( + AutoencoderKL, + ModelMixin, + SD3Transformer2DModel, + StableCascadeUNet, + UNet1DModel, + UNet2DConditionModel, + UNet2DModel, + ) from .optimization import ( get_constant_schedule, get_constant_schedule_with_warmup, diff --git a/mindone/diffusers/models/__init__.py b/mindone/diffusers/models/__init__.py index 9ceb6e47f5..9440cbb1af 100644 --- a/mindone/diffusers/models/__init__.py +++ b/mindone/diffusers/models/__init__.py @@ -24,8 +24,10 @@ "modeling_utils": ["ModelMixin"], "transformers.transformer_2d": ["Transformer2DModel"], "transformers.transformer_sd3": ["SD3Transformer2DModel"], + "unets.unet_1d": ["UNet1DModel"], "unets.unet_2d": ["UNet2DModel"], "unets.unet_2d_condition": ["UNet2DConditionModel"], + "unets.unet_stable_cascade": ["StableCascadeUNet"], } if TYPE_CHECKING: @@ -35,7 +37,7 @@ from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import SD3Transformer2DModel, Transformer2DModel - from .unets import UNet2DConditionModel, UNet2DModel + from .unets import StableCascadeUNet, UNet1DModel, UNet2DConditionModel, UNet2DModel else: import sys diff --git a/mindone/diffusers/models/modeling_utils.py b/mindone/diffusers/models/modeling_utils.py index 97971d452e..1a7e7c7bac 100644 --- a/mindone/diffusers/models/modeling_utils.py +++ b/mindone/diffusers/models/modeling_utils.py @@ -46,8 +46,10 @@ def 
_get_pt2ms_mappings(m):
     mappings = {}  # pt_param_name: (ms_param_name, pt_param_to_ms_param_func)
     for name, cell in m.cells_and_names():
-        if isinstance(cell, nn.Conv1d):
-            mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ops.expand_dims(x, axis=-2)
+        if isinstance(cell, (nn.Conv1d, nn.Conv1dTranspose)):
+            mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ms.Parameter(
+                ops.expand_dims(x, axis=-2), name=x.name
+            )
         elif isinstance(cell, nn.Embedding):
             mappings[f"{name}.weight"] = f"{name}.embedding_table", lambda x: x
         elif isinstance(cell, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
diff --git a/mindone/diffusers/models/unets/__init__.py b/mindone/diffusers/models/unets/__init__.py
index 80f227d070..5e2a1bd011 100644
--- a/mindone/diffusers/models/unets/__init__.py
+++ b/mindone/diffusers/models/unets/__init__.py
@@ -1,2 +1,4 @@
+from .unet_1d import UNet1DModel
 from .unet_2d import UNet2DModel
 from .unet_2d_condition import UNet2DConditionModel
+from .unet_stable_cascade import StableCascadeUNet
diff --git a/mindone/diffusers/models/unets/unet_1d.py b/mindone/diffusers/models/unets/unet_1d.py
new file mode 100644
index 0000000000..4c469e383d
--- /dev/null
+++ b/mindone/diffusers/models/unets/unet_1d.py
@@ -0,0 +1,261 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import mindspore as ms
+from mindspore import nn, ops
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import BaseOutput
+from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from ..modeling_utils import ModelMixin
+from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
+
+
+@dataclass
+class UNet1DOutput(BaseOutput):
+    """
+    The output of [`UNet1DModel`].
+
+    Args:
+        sample (`ms.Tensor` of shape `(batch_size, num_channels, sample_size)`):
+            The hidden states output from the last layer of the model.
+    """
+
+    sample: ms.Tensor
+
+
+class UNet1DModel(ModelMixin, ConfigMixin):
+    r"""
+    A 1D UNet model that takes a noisy sample and a timestep and returns a sample-shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
+        in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
+        out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
+        extra_in_channels (`int`, *optional*, defaults to 0):
+            Number of additional channels to be added to the input of the first down block. Useful for cases where the
+            input data has more channels than what the model was initially designed for.
+        time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use (see the
+            sketch below).
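An aside on those two embedding types: `"fourier"` projects the (possibly continuous) timestep onto a bank of frozen random frequencies, while `"positional"` is the classic transformer sinusoid over `block_out_channels[0]` channels. A rough, self-contained sketch of both schemes — simplified stand-ins for the library's `GaussianFourierProjection` and `Timesteps` layers, not their exact internals:

```python
import numpy as np
import mindspore as ms
from mindspore import ops


def fourier_time_proj(timesteps: ms.Tensor, embedding_size: int = 8) -> ms.Tensor:
    # project t onto frozen random frequencies, then take sin/cos
    weight = ms.Tensor(np.random.randn(embedding_size).astype(np.float32))
    x = timesteps.float().unsqueeze(-1) * weight.unsqueeze(0) * 2 * np.pi
    return ops.cat([ops.sin(x), ops.cos(x)], axis=-1)  # (B, 2 * embedding_size)


def positional_time_proj(timesteps: ms.Tensor, dim: int = 32, freq_shift: float = 0.0) -> ms.Tensor:
    # transformer-style sinusoid: geometric frequency ladder over `dim` channels
    half = dim // 2
    exponent = -np.log(10000.0) * ops.arange(half).float() / (half - freq_shift)
    emb = timesteps.float().unsqueeze(-1) * ops.exp(exponent).unsqueeze(0)
    return ops.cat([ops.sin(emb), ops.cos(emb)], axis=-1)  # (B, dim)


t = ms.Tensor([0, 10, 500], ms.int64)
print(fourier_time_proj(t).shape)     # (3, 16)
print(positional_time_proj(t).shape)  # (3, 32)
```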
+ freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip sin to cos for Fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`): + Tuple of block output channels. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet. + out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet. + act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks. + norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization. + layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block. + downsample_each_block (`int`, *optional*, defaults to `False`): + Experimental feature for using a UNet without upsampling. + """ + + @register_to_config + def __init__( + self, + sample_size: int = 65536, + sample_rate: Optional[int] = None, + in_channels: int = 2, + out_channels: int = 2, + extra_in_channels: int = 0, + time_embedding_type: str = "fourier", + flip_sin_to_cos: bool = True, + use_timestep_embedding: bool = False, + freq_shift: float = 0.0, + down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), + up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), + mid_block_type: Tuple[str] = "UNetMidBlock1D", + out_block_type: str = None, + block_out_channels: Tuple[int] = (32, 32, 64), + act_fn: str = None, + norm_num_groups: int = 8, + layers_per_block: int = 1, + downsample_each_block: bool = False, + ): + super().__init__() + self.sample_size = sample_size + + # time + if time_embedding_type == "fourier": + self.time_proj = GaussianFourierProjection( + embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = 2 * block_out_channels[0] + elif time_embedding_type == "positional": + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift + ) + timestep_input_dim = block_out_channels[0] + + if use_timestep_embedding: + time_embed_dim = block_out_channels[0] * 4 + self.time_mlp = TimestepEmbedding( + in_channels=timestep_input_dim, + time_embed_dim=time_embed_dim, + act_fn=act_fn, + out_dim=block_out_channels[0], + ) + + # down + down_blocks = [] + output_channel = in_channels + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + + if i == 0: + input_channel += extra_in_channels + + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=block_out_channels[0], + add_downsample=not is_final_block or downsample_each_block, + ) + down_blocks.append(down_block) + self.down_blocks = nn.CellList(down_blocks) + + # mid + self.mid_block = get_mid_block( + mid_block_type, + in_channels=block_out_channels[-1], + mid_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + 
embed_dim=block_out_channels[0],
+            num_layers=layers_per_block,
+            add_downsample=downsample_each_block,
+        )
+
+        # up
+        up_blocks = []
+        reversed_block_out_channels = list(reversed(block_out_channels))
+        output_channel = reversed_block_out_channels[0]
+        if out_block_type is None:
+            final_upsample_channels = out_channels
+        else:
+            final_upsample_channels = block_out_channels[0]
+
+        for i, up_block_type in enumerate(up_block_types):
+            prev_output_channel = output_channel
+            output_channel = (
+                reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
+            )
+
+            is_final_block = i == len(block_out_channels) - 1
+
+            up_block = get_up_block(
+                up_block_type,
+                num_layers=layers_per_block,
+                in_channels=prev_output_channel,
+                out_channels=output_channel,
+                temb_channels=block_out_channels[0],
+                add_upsample=not is_final_block,
+            )
+            up_blocks.append(up_block)
+            prev_output_channel = output_channel
+        self.up_blocks = nn.CellList(up_blocks)
+
+        # out
+        num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
+        self.out_block = get_out_block(
+            out_block_type=out_block_type,
+            num_groups_out=num_groups_out,
+            embed_dim=block_out_channels[0],
+            out_channels=out_channels,
+            act_fn=act_fn,
+            fc_dim=block_out_channels[-1] // 4,
+        )
+
+        self.use_timestep_embedding = self.config.use_timestep_embedding
+
+    def construct(
+        self,
+        sample: ms.Tensor,
+        timestep: Union[ms.Tensor, float, int],
+        return_dict: bool = False,
+    ) -> Union[UNet1DOutput, Tuple]:
+        r"""
+        The [`UNet1DModel`] forward method.
+
+        Args:
+            sample (`ms.Tensor`):
+                The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
+            timestep (`ms.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+            return_dict (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.unet_1d.UNet1DOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is the sample tensor.
+        """
+
+        # 1. time
+        timesteps = timestep
+        if not ops.is_tensor(timesteps):
+            timesteps = ms.tensor([timesteps], dtype=ms.int64)
+        elif ops.is_tensor(timesteps) and len(timesteps.shape) == 0:
+            timesteps = timesteps[None]
+
+        timestep_embed = self.time_proj(timesteps)
+
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        timestep_embed = timestep_embed.to(dtype=self.dtype)
+        if self.use_timestep_embedding:
+            timestep_embed = self.time_mlp(timestep_embed)
+        else:
+            timestep_embed = timestep_embed[..., None]
+            # ms.Tensor.repeat is numpy-style (element-wise); tile matches torch's Tensor.repeat semantics
+            timestep_embed = timestep_embed.tile((1, 1, sample.shape[2])).to(sample.dtype)
+            timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))
+
+        # 2. down
+        down_block_res_samples = ()
+        for downsample_block in self.down_blocks:
+            sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
+            down_block_res_samples += res_samples
+
+        # 3. mid
+        if self.mid_block:
+            sample = self.mid_block(sample, timestep_embed)
+
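Pausing the walkthrough for a moment: end to end, the numbered stages above plus the up path below combine as in this sketch, which mirrors the `UNET1D_CASES` entry added to the tests at the bottom of this patch (shapes illustrative):

```python
import numpy as np
import mindspore as ms
from mindone.diffusers import UNet1DModel

unet = UNet1DModel(
    block_out_channels=(32, 64, 128, 256),
    in_channels=14,
    out_channels=14,
    time_embedding_type="positional",
    use_timestep_embedding=True,
    flip_sin_to_cos=False,
    freq_shift=1.0,
    out_block_type="OutConv1DBlock",
    mid_block_type="MidResTemporalBlock1D",
    down_block_types=("DownResnetBlock1D",) * 4,
    up_block_types=("UpResnetBlock1D",) * 3,
    act_fn="swish",
)
sample = ms.Tensor(np.random.randn(4, 14, 16), ms.float32)  # (batch, channels, length)
timestep = ms.Tensor([10] * 4, ms.int64)
(out,) = unet(sample, timestep, return_dict=False)
print(out.shape)  # (4, 14, 16)
```

+        # 4.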
up + for i, upsample_block in enumerate(self.up_blocks): + res_samples = down_block_res_samples[-1:] + down_block_res_samples = down_block_res_samples[:-1] + sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed) + + # 5. post-process + if self.out_block: + sample = self.out_block(sample, timestep_embed) + + if not return_dict: + return (sample,) + + return UNet1DOutput(sample=sample) diff --git a/mindone/diffusers/models/unets/unet_1d_blocks.py b/mindone/diffusers/models/unets/unet_1d_blocks.py new file mode 100644 index 0000000000..1c3de41e1c --- /dev/null +++ b/mindone/diffusers/models/unets/unet_1d_blocks.py @@ -0,0 +1,746 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Optional, Tuple, Union + +import mindspore as ms +from mindspore import nn, ops + +from ..activations import get_activation +from ..resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims + + +class DownResnetBlock1D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + conv_shortcut: bool = False, + temb_channels: int = 32, + groups: int = 32, + groups_out: Optional[int] = None, + non_linearity: Optional[str] = None, + time_embedding_norm: str = "default", + output_scale_factor: float = 1.0, + add_downsample: bool = True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.non_linearity = non_linearity + self.time_embedding_norm = time_embedding_norm + self.add_downsample = add_downsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.CellList(resnets) + + if self.non_linearity is not None: + self.nonlinearity = get_activation(non_linearity)() + + if self.add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) + + def construct(self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None) -> ms.Tensor: + output_states = () + + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.non_linearity is not None: + hidden_states = self.nonlinearity(hidden_states) + + if self.add_downsample: + hidden_states = self.downsample(hidden_states) + + return hidden_states, output_states + + +class UpResnetBlock1D(nn.Cell): + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + temb_channels: int = 32, + groups: int = 32, 
+        groups_out: Optional[int] = None,
+        non_linearity: Optional[str] = None,
+        time_embedding_norm: str = "default",
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.non_linearity = non_linearity
+        self.time_embedding_norm = time_embedding_norm
+        self.add_upsample = add_upsample
+        self.output_scale_factor = output_scale_factor
+
+        if groups_out is None:
+            groups_out = groups
+
+        # there will always be at least one resnet
+        resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)]
+
+        for _ in range(num_layers):
+            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels))
+
+        self.resnets = nn.CellList(resnets)
+
+        if self.non_linearity is not None:
+            self.nonlinearity = get_activation(non_linearity)()
+
+        if self.add_upsample:
+            self.upsample = Upsample1D(out_channels, use_conv_transpose=True)
+
+    def construct(
+        self,
+        hidden_states: ms.Tensor,
+        res_hidden_states_tuple: Optional[Tuple[ms.Tensor, ...]] = None,
+        temb: Optional[ms.Tensor] = None,
+    ) -> ms.Tensor:
+        if res_hidden_states_tuple is not None:
+            res_hidden_states = res_hidden_states_tuple[-1]
+            hidden_states = ops.cat((hidden_states, res_hidden_states), axis=1)
+
+        hidden_states = self.resnets[0](hidden_states, temb)
+        for resnet in self.resnets[1:]:
+            hidden_states = resnet(hidden_states, temb)
+
+        if self.non_linearity is not None:
+            hidden_states = self.nonlinearity(hidden_states)
+
+        if self.add_upsample:
+            hidden_states = self.upsample(hidden_states)
+
+        return hidden_states
+
+
+class ValueFunctionMidBlock1D(nn.Cell):
+    def __init__(self, in_channels: int, out_channels: int, embed_dim: int):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.embed_dim = embed_dim
+
+        self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim)
+        self.down1 = Downsample1D(out_channels // 2, use_conv=True)
+        self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim)
+        self.down2 = Downsample1D(out_channels // 4, use_conv=True)
+
+    def construct(self, x: ms.Tensor, temb: Optional[ms.Tensor] = None) -> ms.Tensor:
+        x = self.res1(x, temb)
+        x = self.down1(x)
+        x = self.res2(x, temb)
+        x = self.down2(x)
+        return x
+
+
+class MidResTemporalBlock1D(nn.Cell):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        embed_dim: int,
+        num_layers: int = 1,
+        add_downsample: bool = False,
+        add_upsample: bool = False,
+        non_linearity: Optional[str] = None,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.add_downsample = add_downsample
+        self.add_upsample = add_upsample
+        self.non_linearity = non_linearity
+
+        # there will always be at least one resnet
+        resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)]
+
+        for _ in range(num_layers):
+            resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim))
+
+        self.resnets = nn.CellList(resnets)
+
+        if self.non_linearity is not None:
+            self.nonlinearity = get_activation(non_linearity)()
+
+        if self.add_upsample:
+            self.upsample = Upsample1D(out_channels, use_conv=True)
+
+        if self.add_downsample:
+            self.downsample = Downsample1D(out_channels, use_conv=True)
+
+        if self.add_upsample and self.add_downsample:
+            raise ValueError("Block cannot downsample and upsample")
+
+    def construct(self, hidden_states: ms.Tensor, temb: ms.Tensor) -> ms.Tensor:
+        hidden_states = self.resnets[0](hidden_states, temb)
+        for resnet in self.resnets[1:]:
+            hidden_states = resnet(hidden_states, temb)
+
+        if self.add_upsample:
+            hidden_states = self.upsample(hidden_states)
+        if self.add_downsample:
+            hidden_states = self.downsample(hidden_states)
+
+        return hidden_states
+
+
+class OutConv1DBlock(nn.Cell):
+    def __init__(self, num_groups_out: int, out_channels: int, embed_dim: int, act_fn: str):
+        super().__init__()
+        self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2, has_bias=True, pad_mode="pad")
+        self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim)
+        self.final_conv1d_act = get_activation(act_fn)()
+        self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1, has_bias=True, pad_mode="valid")
+
+    def construct(self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None) -> ms.Tensor:
+        hidden_states = self.final_conv1d_1(hidden_states)
+        hidden_states = rearrange_dims(hidden_states)
+        hidden_states = self.final_conv1d_gn(hidden_states)
+        hidden_states = rearrange_dims(hidden_states)
+        hidden_states = self.final_conv1d_act(hidden_states)
+        hidden_states = self.final_conv1d_2(hidden_states)
+        return hidden_states
+
+
+class OutValueFunctionBlock(nn.Cell):
+    def __init__(self, fc_dim: int, embed_dim: int, act_fn: str = "mish"):
+        super().__init__()
+        self.final_block = nn.CellList(
+            [
+                nn.Dense(fc_dim + embed_dim, fc_dim // 2),
+                get_activation(act_fn)(),
+                nn.Dense(fc_dim // 2, 1),
+            ]
+        )
+
+    def construct(self, hidden_states: ms.Tensor, temb: ms.Tensor) -> ms.Tensor:
+        hidden_states = hidden_states.view(hidden_states.shape[0], -1)
+        hidden_states = ops.cat((hidden_states, temb), axis=-1)
+        for layer in self.final_block:
+            hidden_states = layer(hidden_states)
+
+        return hidden_states
+
+
+_kernels = {
+    "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8],
+    "cubic": [-0.01171875, -0.03515625, 0.11328125, 0.43359375, 0.43359375, 0.11328125, -0.03515625, -0.01171875],
+    "lanczos3": [
+        0.003689131001010537,
+        0.015056144446134567,
+        -0.03399861603975296,
+        -0.066637322306633,
+        0.13550527393817902,
+        0.44638532400131226,
+        0.44638532400131226,
+        0.13550527393817902,
+        -0.066637322306633,
+        -0.03399861603975296,
+        0.015056144446134567,
+        0.003689131001010537,
+    ],
+}
+
+
+class Downsample1d(nn.Cell):
+    def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
+        super().__init__()
+        self.pad_mode = pad_mode
+        kernel_1d = ms.tensor(_kernels[kernel])
+        self.pad = kernel_1d.shape[0] // 2 - 1
+        # nn.Cell has no register_buffer; keep the fixed FIR kernel as a plain attribute
+        self.kernel = kernel_1d
+
+    def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
+        hidden_states = ops.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
+        weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
+        indices = ops.arange(hidden_states.shape[1])
+        kernel = self.kernel.to(weight.dtype)[None, :].broadcast_to((hidden_states.shape[1], -1))
+        weight[indices, indices] = kernel
+        return ops.conv1d(hidden_states, weight, stride=2)
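`Downsample1d` materializes, on the fly, a conv weight whose diagonal carries the FIR kernel, so every channel is filtered independently and then decimated by the stride-2 convolution. A quick shape check (hypothetical values, assuming the class as defined above):

```python
import numpy as np
import mindspore as ms

down = Downsample1d(kernel="cubic")  # 8-tap kernel, so self.pad == 3
x = ms.Tensor(np.random.randn(1, 4, 32), ms.float32)
# reflect-pad by 3 on each side -> width 38; conv with k=8, stride 2:
# (38 - 8) // 2 + 1 == 16, i.e. exactly half of 32
print(down(x).shape)  # (1, 4, 16)
```

`Upsample1d` below is the mirror image: the same diagonal trick feeds a stride-2 transposed convolution (via the `_conv_transpose1d` helper at the end of this file), doubling the length instead.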
+class Upsample1d(nn.Cell):
+    def __init__(self, kernel: str = "linear", pad_mode: str = "reflect"):
+        super().__init__()
+        self.pad_mode = pad_mode
+        kernel_1d = ms.tensor(_kernels[kernel]) * 2
+        self.pad = kernel_1d.shape[0] // 2 - 1
+        self.kernel = kernel_1d
+
+    def construct(self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None) -> ms.Tensor:
+        hidden_states = ops.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
+        weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
+        indices = ops.arange(hidden_states.shape[1])
+        kernel = self.kernel.to(weight.dtype)[None, :].broadcast_to((hidden_states.shape[1], -1))
+        weight[indices, indices] = kernel
+        return _conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)
+
+
+class SelfAttention1d(nn.Cell):
+    def __init__(self, in_channels: int, n_head: int = 1, dropout_rate: float = 0.0):
+        super().__init__()
+        self.channels = in_channels
+        self.group_norm = nn.GroupNorm(1, num_channels=in_channels)
+        self.num_heads = n_head
+
+        self.query = nn.Dense(self.channels, self.channels)
+        self.key = nn.Dense(self.channels, self.channels)
+        self.value = nn.Dense(self.channels, self.channels)
+
+        self.proj_attn = nn.Dense(self.channels, self.channels, has_bias=True)
+
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def transpose_for_scores(self, projection: ms.Tensor) -> ms.Tensor:
+        new_projection_shape = projection.shape[:-1] + (self.num_heads, -1)
+        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
+        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
+        return new_projection
+
+    def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
+        residual = hidden_states
+        batch, channel_dim, seq = hidden_states.shape
+
+        hidden_states = self.group_norm(hidden_states)
+        hidden_states = hidden_states.swapaxes(1, 2)
+
+        query_proj = self.query(hidden_states)
+        key_proj = self.key(hidden_states)
+        value_proj = self.value(hidden_states)
+
+        query_states = self.transpose_for_scores(query_proj)
+        key_states = self.transpose_for_scores(key_proj)
+        value_states = self.transpose_for_scores(value_proj)
+
+        scale = 1 / math.sqrt(math.sqrt(key_states.shape[-1]))
+
+        attention_scores = ops.matmul(query_states * scale, key_states.swapaxes(-1, -2) * scale)
+        attention_probs = ops.softmax(attention_scores, axis=-1)
+
+        # compute attention output
+        hidden_states = ops.matmul(attention_probs, value_states)
+
+        hidden_states = hidden_states.permute(0, 2, 1, 3)
+        new_hidden_states_shape = hidden_states.shape[:-2] + (self.channels,)
+        hidden_states = hidden_states.view(new_hidden_states_shape)
+
+        # compute next hidden_states
+        hidden_states = self.proj_attn(hidden_states)
+        hidden_states = hidden_states.swapaxes(1, 2)
+        hidden_states = self.dropout(hidden_states)
+
+        output = hidden_states + residual
+
+        return output
+
+
+class ResConvBlock(nn.Cell):
+    def __init__(self, in_channels: int, mid_channels: int, out_channels: int, is_last: bool = False):
+        super().__init__()
+        self.is_last = is_last
+        self.has_conv_skip = in_channels != out_channels
+
+        if self.has_conv_skip:
+            self.conv_skip = nn.Conv1d(in_channels, out_channels, 1, pad_mode="valid")
+
+        self.conv_1 = nn.Conv1d(in_channels, mid_channels, 5, padding=2, has_bias=True, pad_mode="pad")
+        self.group_norm_1 = nn.GroupNorm(1, mid_channels)
+        self.gelu_1 = nn.GELU()
+        self.conv_2 = nn.Conv1d(mid_channels, out_channels, 5, padding=2, has_bias=True, pad_mode="pad")
+
+        if not self.is_last:
+            self.group_norm_2 = nn.GroupNorm(1, out_channels)
+            self.gelu_2 = nn.GELU()
+
+    def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
+        residual = self.conv_skip(hidden_states) if self.has_conv_skip else hidden_states
+
+        hidden_states = self.conv_1(hidden_states)
+        hidden_states = self.group_norm_1(hidden_states)
+        hidden_states = self.gelu_1(hidden_states)
+        hidden_states =
self.conv_2(hidden_states) + + if not self.is_last: + hidden_states = self.group_norm_2(hidden_states) + hidden_states = self.gelu_2(hidden_states) + + output = hidden_states + residual + return output + + +class UNetMidBlock1D(nn.Cell): + def __init__(self, mid_channels: int, in_channels: int, out_channels: Optional[int] = None): + super().__init__() + + out_channels = in_channels if out_channels is None else out_channels + + # there is always at least one resnet + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + self.up = Upsample1d(kernel="cubic") + + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + + def construct(self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None) -> ms.Tensor: + hidden_states = self.down(hidden_states) + for attn, resnet in zip(self.attentions, self.resnets): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class AttnDownBlock1D(nn.Cell): + def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + + def construct(self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None) -> ms.Tensor: + hidden_states = self.down(hidden_states) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + return hidden_states, (hidden_states,) + + +class DownBlock1D(nn.Cell): + def __init__(self, out_channels: int, in_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + self.down = Downsample1d("cubic") + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.CellList(resnets) + + def construct(self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None) -> ms.Tensor: + hidden_states = self.down(hidden_states) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states, (hidden_states,) + + +class DownBlock1DNoSkip(nn.Cell): + def __init__(self, out_channels: int, in_channels: 
int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.CellList(resnets) + + def construct(self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None) -> ms.Tensor: + hidden_states = ops.cat([hidden_states, temb], axis=1) + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states, (hidden_states,) + + +class AttnUpBlock1D(nn.Cell): + def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = out_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + attentions = [ + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(mid_channels, mid_channels // 32), + SelfAttention1d(out_channels, out_channels // 32), + ] + + self.attentions = nn.CellList(attentions) + self.resnets = nn.CellList(resnets) + self.up = Upsample1d(kernel="cubic") + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states) + hidden_states = attn(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class UpBlock1D(nn.Cell): + def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels), + ] + + self.resnets = nn.CellList(resnets) + self.up = Upsample1d(kernel="cubic") + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + hidden_states = self.up(hidden_states) + + return hidden_states + + +class UpBlock1DNoSkip(nn.Cell): + def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None): + super().__init__() + mid_channels = in_channels if mid_channels is None else mid_channels + + resnets = [ + ResConvBlock(2 * in_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, mid_channels), + ResConvBlock(mid_channels, mid_channels, out_channels, is_last=True), + ] + + self.resnets = nn.CellList(resnets) + + def construct( + self, + hidden_states: ms.Tensor, + res_hidden_states_tuple: Tuple[ms.Tensor, ...], + temb: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = ops.cat([hidden_states, res_hidden_states], axis=1) + + for resnet in self.resnets: + 
hidden_states = resnet(hidden_states) + + return hidden_states + + +DownBlockType = Union[DownResnetBlock1D, DownBlock1D, AttnDownBlock1D, DownBlock1DNoSkip] +MidBlockType = Union[MidResTemporalBlock1D, ValueFunctionMidBlock1D, UNetMidBlock1D] +OutBlockType = Union[OutConv1DBlock, OutValueFunctionBlock] +UpBlockType = Union[UpResnetBlock1D, UpBlock1D, AttnUpBlock1D, UpBlock1DNoSkip] + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, +) -> DownBlockType: + if down_block_type == "DownResnetBlock1D": + return DownResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "DownBlock1D": + return DownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "AttnDownBlock1D": + return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "DownBlock1DNoSkip": + return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type: str, num_layers: int, in_channels: int, out_channels: int, temb_channels: int, add_upsample: bool +) -> UpBlockType: + if up_block_type == "UpResnetBlock1D": + return UpResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + ) + elif up_block_type == "UpBlock1D": + return UpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "AttnUpBlock1D": + return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "UpBlock1DNoSkip": + return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels) + raise ValueError(f"{up_block_type} does not exist.") + + +def get_mid_block( + mid_block_type: str, + num_layers: int, + in_channels: int, + mid_channels: int, + out_channels: int, + embed_dim: int, + add_downsample: bool, +) -> MidBlockType: + if mid_block_type == "MidResTemporalBlock1D": + return MidResTemporalBlock1D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + embed_dim=embed_dim, + add_downsample=add_downsample, + ) + elif mid_block_type == "ValueFunctionMidBlock1D": + return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim) + elif mid_block_type == "UNetMidBlock1D": + return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels) + raise ValueError(f"{mid_block_type} does not exist.") + + +def get_out_block( + *, out_block_type: str, num_groups_out: int, embed_dim: int, out_channels: int, act_fn: str, fc_dim: int +) -> Optional[OutBlockType]: + if out_block_type == "OutConv1DBlock": + return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) + elif out_block_type == "ValueFunction": + return OutValueFunctionBlock(fc_dim, embed_dim, act_fn) + return None + + +def _conv_transpose1d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + # Equivalence of torch.nn.functional.conv_transpose1d + assert output_padding == 0, "Only support output_padding == 0 so far." 
+ + if isinstance(stride, tuple): + stride = stride[0] + if isinstance(dilation, tuple): + dilation = dilation[0] + if isinstance(padding, tuple): + padding = padding[0] + + # InferShape manually + # Format adapted from https://pytorch.org/docs/stable/generated/torch.nn.functional.conv_transpose1d.html + batch_size, in_channels, iW = input.shape + _, out_channels_divide_groups, kW = weight.shape + + out_channels = out_channels_divide_groups * groups + # outW = (iW - 1) * stride - 2 * padding + dilation * (kW - 1) + 1 + + if bias is None: + op_conv_transpose1d = nn.Conv1dTranspose( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kW, + stride=stride, + pad_mode="pad", + padding=padding, + dilation=dilation, + weight_init=weight, + group=groups, + ) + else: + op_conv_transpose1d = nn.Conv1dTranspose( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kW, + stride=stride, + pad_mode="pad", + padding=padding, + dilation=dilation, + group=groups, + has_bias=True, + weight_init=weight, + bias_init=bias, + ) + + outputs = op_conv_transpose1d(input) + + return outputs diff --git a/mindone/diffusers/models/unets/unet_stable_cascade.py b/mindone/diffusers/models/unets/unet_stable_cascade.py new file mode 100644 index 0000000000..a4bdb3dfec --- /dev/null +++ b/mindone/diffusers/models/unets/unet_stable_cascade.py @@ -0,0 +1,619 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
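Before the new Stable Cascade file gets going, a quick sanity check on the `_conv_transpose1d` helper that closed out `unet_1d_blocks.py` above — a hypothetical call, assuming MindSpore 2.x `nn.Conv1dTranspose` accepts the 3-D weight via `weight_init` as the helper does:

```python
import numpy as np
import mindspore as ms

x = ms.Tensor(np.random.randn(1, 4, 16), ms.float32)  # (N, C_in, iW)
w = ms.Tensor(np.random.randn(4, 4, 4), ms.float32)   # (C_in, C_out / groups, kW)
y = _conv_transpose1d(x, w, stride=2, padding=1)
# outW = (iW - 1) * stride - 2 * padding + dilation * (kW - 1) + 1
#      = 15 * 2 - 2 + 3 + 1 = 32
print(y.shape)  # (1, 4, 32)
```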
+ +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np + +import mindspore as ms +from mindspore import nn, ops +from mindspore.common.initializer import Constant, Normal, XavierNormal, initializer + +from ...configuration_utils import ConfigMixin, register_to_config + +# from ...loaders.unet import FromOriginalUNetMixin +from ...utils import BaseOutput +from ..attention_processor import Attention +from ..modeling_utils import ModelMixin +from ..normalization import LayerNorm + + +# Copied from diffusers.pipelines.wuerstchen.modeling_wuerstchen_common.WuerstchenLayerNorm with WuerstchenLayerNorm -> SDCascadeLayerNorm +class SDCascadeLayerNorm(LayerNorm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def construct(self, x): + x = x.permute(0, 2, 3, 1) + x = super().construct(x) + return x.permute(0, 3, 1, 2) + + +class SDCascadeTimestepBlock(nn.Cell): + def __init__(self, c, c_timestep, conds=[]): + super().__init__() + self.mapper = nn.Dense(c_timestep, c * 2) + self.conds = conds + for cname in conds: + setattr(self, f"mapper_{cname}", nn.Dense(c_timestep, c * 2)) + + def construct(self, x, t): + t = t.chunk(len(self.conds) + 1, axis=1) + a, b = self.mapper(t[0])[:, :, None, None].chunk(2, axis=1) + for i, c in enumerate(self.conds): + ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, axis=1) + a, b = a + ac, b + bc + return x * (1 + a) + b + + +class SDCascadeResBlock(nn.Cell): + def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0): + super().__init__() + self.depthwise = nn.Conv2d( + c, c, kernel_size=kernel_size, padding=kernel_size // 2, group=c, pad_mode="pad", has_bias=True + ) + self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.channelwise = nn.SequentialCell( + nn.Dense(c + c_skip, c * 4), + nn.GELU(), + GlobalResponseNorm(c * 4), + nn.Dropout(p=dropout), + nn.Dense(c * 4, c), + ) + + def construct(self, x, x_skip=None): + x_res = x + x = self.norm(self.depthwise(x)) + if x_skip is not None: + x = ops.cat([x, x_skip], axis=1) + x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + x_res + + +# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 +class GlobalResponseNorm(nn.Cell): + def __init__(self, dim): + super().__init__() + self.gamma = ms.Parameter(ops.zeros((1, 1, 1, dim)), name="gamma") + self.beta = ms.Parameter(ops.zeros((1, 1, 1, dim)), name="beta") + + def construct(self, x): + agg_norm = ops.norm(x, ord=2, dim=(1, 2), keepdim=True).to(x.dtype) + stand_div_norm = agg_norm / (agg_norm.mean(axis=-1, keep_dims=True) + 1e-6) + return self.gamma * (x * stand_div_norm) + self.beta + x + + +class SDCascadeAttnBlock(nn.Cell): + def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0): + super().__init__() + linear_cls = nn.Dense + + self.self_attn = self_attn + self.norm = SDCascadeLayerNorm(c, elementwise_affine=False, eps=1e-6) + self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True) + self.kv_mapper = nn.SequentialCell(nn.SiLU(), linear_cls(c_cond, c)) + + def construct(self, x, kv): + kv = self.kv_mapper(kv) + norm_x = self.norm(x) + if self.self_attn: + batch_size, channel, _, _ = x.shape + kv = ops.cat([norm_x.view(batch_size, channel, -1).swapaxes(1, 2), kv], axis=1) + x = x + self.attention(norm_x, encoder_hidden_states=kv) + return x + + +class UpDownBlock2d(nn.Cell): + def 
__init__(self, in_channels, out_channels, mode, enabled=True): + super().__init__() + if mode not in ["up", "down"]: + raise ValueError(f"{mode} not supported") + interpolation = ( + nn.Upsample(scale_factor=2 if mode == "up" else 0.5, mode="bilinear", align_corners=True) + if enabled + else nn.Identity() + ) + mapping = nn.Conv2d(in_channels, out_channels, kernel_size=1, has_bias=True, pad_mode="valid") + self.blocks = nn.CellList([interpolation, mapping] if mode == "up" else [mapping, interpolation]) + + def construct(self, x): + for block in self.blocks: + x = block(x) + return x + + +@dataclass +class StableCascadeUNetOutput(BaseOutput): + sample: ms.Tensor = None + + +# class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalUNetMixin): +class StableCascadeUNet(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 16, + out_channels: int = 16, + timestep_ratio_embedding_dim: int = 64, + patch_size: int = 1, + conditioning_dim: int = 2048, + block_out_channels: Tuple[int] = (2048, 2048), + num_attention_heads: Tuple[int] = (32, 32), + down_num_layers_per_block: Tuple[int] = (8, 24), + up_num_layers_per_block: Tuple[int] = (24, 8), + down_blocks_repeat_mappers: Optional[Tuple[int]] = ( + 1, + 1, + ), + up_blocks_repeat_mappers: Optional[Tuple[int]] = (1, 1), + block_types_per_layer: Tuple[Tuple[str]] = ( + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ), + clip_text_in_channels: Optional[int] = None, + clip_text_pooled_in_channels=1280, + clip_image_in_channels: Optional[int] = None, + clip_seq=4, + effnet_in_channels: Optional[int] = None, + pixel_mapper_in_channels: Optional[int] = None, + kernel_size=3, + dropout: Union[float, Tuple[float]] = (0.1, 0.1), + self_attn: Union[bool, Tuple[bool]] = True, + timestep_conditioning_type: Tuple[str] = ("sca", "crp"), + switch_level: Optional[Tuple[bool]] = None, + ): + """ + + Parameters: + in_channels (`int`, defaults to 16): + Number of channels in the input sample. + out_channels (`int`, defaults to 16): + Number of channels in the output sample. + timestep_ratio_embedding_dim (`int`, defaults to 64): + Dimension of the projected time embedding. + patch_size (`int`, defaults to 1): + Patch size to use for pixel unshuffling layer + conditioning_dim (`int`, defaults to 2048): + Dimension of the image and text conditional embedding. + block_out_channels (Tuple[int], defaults to (2048, 2048)): + Tuple of output channels for each block. + num_attention_heads (Tuple[int], defaults to (32, 32)): + Number of attention heads in each attention block. Set to -1 to if block types in a layer do not have attention. + down_num_layers_per_block (Tuple[int], defaults to [8, 24]): + Number of layers in each down block. + up_num_layers_per_block (Tuple[int], defaults to [24, 8]): + Number of layers in each up block. + down_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]): + Number of 1x1 Convolutional layers to repeat in each down block. + up_blocks_repeat_mappers (Tuple[int], optional, defaults to [1, 1]): + Number of 1x1 Convolutional layers to repeat in each up block. + block_types_per_layer (Tuple[Tuple[str]], optional, + defaults to ( + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock") + ): + Block types used in each layer of the up/down blocks. 
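To make these layout knobs concrete before the remaining parameters: a deliberately tiny, purely hypothetical configuration that exercises the constructor (released checkpoints use the 2048-wide defaults above):

```python
from mindone.diffusers import StableCascadeUNet

# all values here are illustrative, chosen only to keep the model small
tiny_unet = StableCascadeUNet(
    in_channels=4,
    out_channels=4,
    timestep_ratio_embedding_dim=16,
    conditioning_dim=32,
    clip_text_pooled_in_channels=32,
    block_out_channels=(32, 64),
    num_attention_heads=(4, 4),
    down_num_layers_per_block=(1, 1),
    up_num_layers_per_block=(1, 1),
)
```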
+ clip_text_in_channels (`int`, *optional*, defaults to `None`): + Number of input channels for CLIP based text conditioning. + clip_text_pooled_in_channels (`int`, *optional*, defaults to 1280): + Number of input channels for pooled CLIP text embeddings. + clip_image_in_channels (`int`, *optional*): + Number of input channels for CLIP based image conditioning. + clip_seq (`int`, *optional*, defaults to 4): + effnet_in_channels (`int`, *optional*, defaults to `None`): + Number of input channels for effnet conditioning. + pixel_mapper_in_channels (`int`, defaults to `None`): + Number of input channels for pixel mapper conditioning. + kernel_size (`int`, *optional*, defaults to 3): + Kernel size to use in the block convolutional layers. + dropout (Tuple[float], *optional*, defaults to (0.1, 0.1)): + Dropout to use per block. + self_attn (Union[bool, Tuple[bool]]): + Tuple of booleans that determine whether to use self attention in a block or not. + timestep_conditioning_type (Tuple[str], defaults to ("sca", "crp")): + Timestep conditioning type. + switch_level (Optional[Tuple[bool]], *optional*, defaults to `None`): + Tuple that indicates whether upsampling or downsampling should be applied in a block + """ + + super().__init__() + + if len(block_out_channels) != len(down_num_layers_per_block): + raise ValueError( + f"Number of elements in `down_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(up_num_layers_per_block): + raise ValueError( + f"Number of elements in `up_num_layers_per_block` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(down_blocks_repeat_mappers): + raise ValueError( + f"Number of elements in `down_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(up_blocks_repeat_mappers): + raise ValueError( + f"Number of elements in `up_blocks_repeat_mappers` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + elif len(block_out_channels) != len(block_types_per_layer): + raise ValueError( + f"Number of elements in `block_types_per_layer` must match the length of `block_out_channels`: {len(block_out_channels)}" + ) + + if isinstance(dropout, float): + dropout = (dropout,) * len(block_out_channels) + if isinstance(self_attn, bool): + self_attn = (self_attn,) * len(block_out_channels) + + # CONDITIONING + if effnet_in_channels is not None: + self.effnet_mapper = nn.SequentialCell( + nn.Conv2d( + effnet_in_channels, block_out_channels[0] * 4, kernel_size=1, has_bias=True, pad_mode="valid" + ), + nn.GELU(), + nn.Conv2d( + block_out_channels[0] * 4, block_out_channels[0], kernel_size=1, has_bias=True, pad_mode="valid" + ), + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + ) + else: + self.effnet_mapper = None + if pixel_mapper_in_channels is not None: + self.pixels_mapper = nn.SequentialCell( + nn.Conv2d( + pixel_mapper_in_channels, block_out_channels[0] * 4, kernel_size=1, has_bias=True, pad_mode="valid" + ), + nn.GELU(), + nn.Conv2d( + block_out_channels[0] * 4, block_out_channels[0], kernel_size=1, has_bias=True, pad_mode="valid" + ), + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + ) + else: + self.pixels_mapper = None + + self.clip_txt_pooled_mapper = nn.Dense(clip_text_pooled_in_channels, conditioning_dim * clip_seq) + if clip_text_in_channels 
is not None: + self.clip_txt_mapper = nn.Dense(clip_text_in_channels, conditioning_dim) + if clip_image_in_channels is not None: + self.clip_img_mapper = nn.Dense(clip_image_in_channels, conditioning_dim * clip_seq) + self.clip_norm = LayerNorm(conditioning_dim, elementwise_affine=False, eps=1e-6) + + self.embedding = nn.SequentialCell( + nn.PixelUnshuffle(patch_size), + nn.Conv2d( + in_channels * (patch_size**2), block_out_channels[0], kernel_size=1, has_bias=True, pad_mode="valid" + ), + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + ) + + def get_block(block_type, in_channels, nhead, c_skip=0, dropout=0, self_attn=True): + if block_type == "SDCascadeResBlock": + return SDCascadeResBlock(in_channels, c_skip, kernel_size=kernel_size, dropout=dropout) + elif block_type == "SDCascadeAttnBlock": + return SDCascadeAttnBlock(in_channels, conditioning_dim, nhead, self_attn=self_attn, dropout=dropout) + elif block_type == "SDCascadeTimestepBlock": + return SDCascadeTimestepBlock( + in_channels, timestep_ratio_embedding_dim, conds=timestep_conditioning_type + ) + else: + raise ValueError(f"Block type {block_type} not supported") + + # BLOCKS + # -- down blocks + down_blocks = [] + down_downscalers = [] + down_repeat_mappers = [] + for i in range(len(block_out_channels)): + if i > 0: + down_downscalers.append( + nn.SequentialCell( + SDCascadeLayerNorm(block_out_channels[i - 1], elementwise_affine=False, eps=1e-6), + UpDownBlock2d( + block_out_channels[i - 1], block_out_channels[i], mode="down", enabled=switch_level[i - 1] + ) + if switch_level is not None + else nn.Conv2d( + block_out_channels[i - 1], + block_out_channels[i], + kernel_size=2, + stride=2, + has_bias=True, + pad_mode="valid", + ), + ) + ) + else: + down_downscalers.append(nn.Identity()) + + down_block = [] + for _ in range(down_num_layers_per_block[i]): + for block_type in block_types_per_layer[i]: + block = get_block( + block_type, + block_out_channels[i], + num_attention_heads[i], + dropout=dropout[i], + self_attn=self_attn[i], + ) + down_block.append(block) + down_blocks.append(nn.CellList(down_block)) + + if down_blocks_repeat_mappers is not None: + block_repeat_mappers = [] + for _ in range(down_blocks_repeat_mappers[i] - 1): + block_repeat_mappers.append( + nn.Conv2d( + block_out_channels[i], block_out_channels[i], kernel_size=1, has_bias=True, pad_mode="valid" + ) + ) + down_repeat_mappers.append(nn.CellList(block_repeat_mappers)) + + self.down_blocks = nn.CellList(down_blocks) + self.down_downscalers = nn.CellList(down_downscalers) + self.down_repeat_mappers = nn.CellList(down_repeat_mappers) + + # -- up blocks + up_blocks = [] + up_upscalers = [] + up_repeat_mappers = [] + for i in reversed(range(len(block_out_channels))): + if i > 0: + up_upscalers.append( + nn.SequentialCell( + SDCascadeLayerNorm(block_out_channels[i], elementwise_affine=False, eps=1e-6), + UpDownBlock2d( + block_out_channels[i], block_out_channels[i - 1], mode="up", enabled=switch_level[i - 1] + ) + if switch_level is not None + else nn.Conv2dTranspose( + block_out_channels[i], + block_out_channels[i - 1], + kernel_size=2, + stride=2, + has_bias=True, + pad_mode="valid", + ), + ) + ) + else: + up_upscalers.append(nn.Identity()) + + up_block = [] + for j in range(up_num_layers_per_block[::-1][i]): + for k, block_type in enumerate(block_types_per_layer[i]): + c_skip = block_out_channels[i] if i < len(block_out_channels) - 1 and j == k == 0 else 0 + block = get_block( + block_type, + block_out_channels[i], + 
num_attention_heads[i], + c_skip=c_skip, + dropout=dropout[i], + self_attn=self_attn[i], + ) + up_block.append(block) + up_blocks.append(nn.CellList(up_block)) + + if up_blocks_repeat_mappers is not None: + block_repeat_mappers = [] + for _ in range(up_blocks_repeat_mappers[::-1][i] - 1): + block_repeat_mappers.append( + nn.Conv2d(block_out_channels[i], block_out_channels[i], kernel_size=1, has_bias=True) + ) + up_repeat_mappers.append(nn.CellList(block_repeat_mappers)) + + self.up_blocks = nn.CellList(up_blocks) + self.up_upscalers = nn.CellList(up_upscalers) + self.up_repeat_mappers = nn.CellList(up_repeat_mappers) + + # OUTPUT + self.clf = nn.SequentialCell( + SDCascadeLayerNorm(block_out_channels[0], elementwise_affine=False, eps=1e-6), + nn.Conv2d( + block_out_channels[0], out_channels * (patch_size**2), kernel_size=1, has_bias=True, pad_mode="valid" + ), + nn.PixelShuffle(patch_size), + ) + + self._gradient_checkpointing = False + + # def _set_gradient_checkpointing(self, value=False): + # self._gradient_checkpointing = value + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Dense)): + m.weight.set_data(initializer(XavierNormal(), m.weight.shape, m.weight.dtype)) + if m.bias is not None: + m.bias.set_data(initializer(Constant(0), m.bias.shape, m.bias.dtype)) + + self.clip_txt_pooled_mapper.weight.set_data( + initializer( + Normal(sigma=0.02), self.clip_txt_pooled_mapper.weight.shape, self.clip_txt_pooled_mapper.weight.dtype + ) + ) + self.clip_txt_mapper.weight.set_data( + initializer(Normal(sigma=0.02), self.clip_txt_mapper.weight.shape, self.clip_txt_mapper.weight.dtype) + ) if hasattr(self, "clip_txt_mapper") else None + self.clip_img_mapper.weight.set_data( + initializer(Normal(sigma=0.02), self.clip_img_mapper.weight.shape, self.clip_img_mapper.weight.dtype) + ) if hasattr(self, "clip_img_mapper") else None + + if hasattr(self, "effnet_mapper"): + self.effnet_mapper[0].weight.set_data( + initializer(Normal(sigma=0.02), self.effnet_mapper[0].weight.shape, self.effnet_mapper[0].weight.dtype) + ) # conditionings + self.effnet_mapper[2].weight.set_data( + initializer(Normal(sigma=0.02), self.effnet_mapper[2].weight.shape, self.effnet_mapper[2].weight.dtype) + ) # conditionings + + if hasattr(self, "pixels_mapper"): + self.pixels_mapper[0].weight.set_data( + initializer(Normal(sigma=0.02), self.pixels_mapper[0].weight.shape, self.pixels_mapper[0].weight.dtype) + ) # conditionings + self.pixels_mapper[2].weight.set_data( + initializer(Normal(sigma=0.02), self.pixels_mapper[2].weight.shape, self.pixels_mapper[2].weight.dtype) + ) # conditionings + + self.embedding[1].weight.set_data( + initializer(XavierNormal(gain=0.02), self.embedding[1].weight.shape, self.embedding[1].weight.dtype) + ) # inputs + self.clf[1].weight.set_data(initializer(0, self.clf[1].weight.shape, self.clf[1].weight.dtype)) # outputs + + # blocks + for level_block in self.down_blocks + self.up_blocks: + for block in level_block: + if isinstance(block, SDCascadeResBlock): + block.channelwise[-1].weight *= np.sqrt(1 / sum(self.config.blocks[0])) + elif isinstance(block, SDCascadeTimestepBlock): + nn.init.constant_(block.mapper.weight, 0) + + def get_timestep_ratio_embedding(self, timestep_ratio, max_positions=10000): + r = timestep_ratio * max_positions + half_dim = self.config["timestep_ratio_embedding_dim"] // 2 + + emb = math.log(max_positions) / (half_dim - 1) + emb = ops.arange(half_dim).float().mul(-emb).exp() + emb = r[:, None] * emb[None, :] + emb = ops.cat([emb.sin(), emb.cos()], axis=1) + + 
if self.config["timestep_ratio_embedding_dim"] % 2 == 1: # zero pad + emb = ops.pad(emb, (0, 1), mode="constant") + + return emb.to(dtype=r.dtype) + + def get_clip_embeddings(self, clip_txt_pooled, clip_txt=None, clip_img=None): + if len(clip_txt_pooled.shape) == 2: + clip_txt_pool = clip_txt_pooled.unsqueeze(1) + clip_txt_pool = self.clip_txt_pooled_mapper(clip_txt_pooled).view( + clip_txt_pooled.shape[0], clip_txt_pooled.shape[1] * self.config["clip_seq"], -1 + ) + if clip_txt is not None and clip_img is not None: + clip_txt = self.clip_txt_mapper(clip_txt) + if len(clip_img.shape) == 2: + clip_img = clip_img.unsqueeze(1) + clip_img = self.clip_img_mapper(clip_img).view( + clip_img.shape[0], clip_img.shape[1] * self.config["clip_seq"], -1 + ) + clip = ops.cat([clip_txt, clip_txt_pool, clip_img], axis=1) + else: + clip = clip_txt_pool + return self.clip_norm(clip) + + @property + def gradient_checkpointing(self): + return self._gradient_checkpointing + + @gradient_checkpointing.setter + def gradient_checkpointing(self, value): + self._gradient_checkpointing = value + # we exclude 0-th resnet following huggingface/diffusers. HF does this just for simplicity in forward? + for block in self.down_blocks: + block._recompute(value) + for block in self.up_blocks: + block._recompute(value) + + def _down_encode(self, x, r_embed, clip): + level_outputs = [] + block_group = list(zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)) + + for down_block, downscaler, repmap in block_group: + x = downscaler(x) + for i in range(len(repmap) + 1): + for block in down_block: + if isinstance(block, SDCascadeResBlock): + x = block(x) + elif isinstance(block, SDCascadeAttnBlock): + x = block(x, clip) + elif isinstance(block, SDCascadeTimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + if i < len(repmap): + x = repmap[i](x) + level_outputs.insert(0, x) + return level_outputs + + def _up_decode(self, level_outputs, r_embed, clip): + x = level_outputs[0] + block_group = list(zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)) + + for i, (up_block, upscaler, repmap) in enumerate(block_group): + for j in range(len(repmap) + 1): + for k, block in enumerate(up_block): + if isinstance(block, SDCascadeResBlock): + skip = level_outputs[i] if k == 0 and i > 0 else None + if skip is not None and (x.shape[-1] != skip.shape[-1] or x.shape[-2] != skip.shape[-2]): + x = ops.interpolate(x.float(), skip.shape[-2:], mode="bilinear", align_corners=True) + x = block(x, skip) + elif isinstance(block, SDCascadeAttnBlock): + x = block(x, clip) + elif isinstance(block, SDCascadeTimestepBlock): + x = block(x, r_embed) + else: + x = block(x) + if j < len(repmap): + x = repmap[j](x) + x = upscaler(x) + return x + + def construct( + self, + sample, + timestep_ratio, + clip_text_pooled, + clip_text=None, + clip_img=None, + effnet=None, + pixels=None, + sca=None, + crp=None, + return_dict=False, + ): + if pixels is None: + pixels = sample.new_zeros((sample.shape[0], 3, 8, 8)) + + # Process the conditioning embeddings + timestep_ratio_embed = self.get_timestep_ratio_embedding(timestep_ratio) + for c in self.config["timestep_conditioning_type"]: + if c == "sca": + cond = sca + elif c == "crp": + cond = crp + else: + cond = None + t_cond = cond or ops.zeros_like(timestep_ratio) + timestep_ratio_embed = ops.cat([timestep_ratio_embed, self.get_timestep_ratio_embedding(t_cond)], axis=1) + clip = self.get_clip_embeddings(clip_txt_pooled=clip_text_pooled, clip_txt=clip_text, clip_img=clip_img) + + # Model 
Blocks + x = self.embedding(sample) + if self.effnet_mapper is not None and effnet is not None: + x = x + self.effnet_mapper(ops.interpolate(effnet, size=x.shape[-2:], mode="bilinear", align_corners=True)) + if self.pixels_mapper is not None: + x = x + ops.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode="bilinear", align_corners=True) + level_outputs = self._down_encode(x, timestep_ratio_embed, clip) + x = self._up_decode(level_outputs, timestep_ratio_embed, clip) + sample = self.clf(x) + + if not return_dict: + return (sample,) + return StableCascadeUNetOutput(sample=sample) diff --git a/tests/diffusers/models/test_layers.py b/tests/diffusers/models/test_layers.py index effe25835a..a3ed2f1eb1 100644 --- a/tests/diffusers/models/test_layers.py +++ b/tests/diffusers/models/test_layers.py @@ -47,8 +47,10 @@ def get_pt2ms_mappings(m): mappings = {} # pt_param_name: (ms_param_name, pt_param_to_ms_param_func) for name, cell in m.cells_and_names(): - if isinstance(cell, nn.Conv1d): - mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ops.expand_dims(x, axis=-2) + if isinstance(cell, (nn.Conv1d, nn.Conv1dTranspose)): + mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ms.Parameter( + ops.expand_dims(x, axis=-2), name=x.name + ) elif isinstance(cell, nn.Embedding): mappings[f"{name}.weight"] = f"{name}.embedding_table", lambda x: x elif isinstance(cell, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)): diff --git a/tests/diffusers/models/test_layers_cases.py b/tests/diffusers/models/test_layers_cases.py index 4b0bcf1934..091543c5ef 100644 --- a/tests/diffusers/models/test_layers_cases.py +++ b/tests/diffusers/models/test_layers_cases.py @@ -426,7 +426,165 @@ ] +UNET1D_CASES = [ + [ + "UNet1DModel", + "diffusers.models.unets.unet_1d.UNet1DModel", + "mindone.diffusers.models.unets.unet_1d.UNet1DModel", + (), + dict( + block_out_channels=(32, 64, 128, 256), + in_channels=14, + out_channels=14, + time_embedding_type="positional", + use_timestep_embedding=True, + flip_sin_to_cos=False, + freq_shift=1.0, + out_block_type="OutConv1DBlock", + mid_block_type="MidResTemporalBlock1D", + down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), + up_block_types=("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"), + act_fn="swish", + ), + (), + { + "sample": np.random.randn(4, 14, 16).astype(np.float32), + "timestep": np.array([10] * 4, dtype=np.int64), + "return_dict": False, + }, + ], + [ + "UNetRLModel", + "diffusers.models.unets.unet_1d.UNet1DModel", + "mindone.diffusers.models.unets.unet_1d.UNet1DModel", + (), + dict( + in_channels=14, + out_channels=14, + down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), + up_block_types=(), + out_block_type="ValueFunction", + mid_block_type="ValueFunctionMidBlock1D", + block_out_channels=(32, 64, 128, 256), + layers_per_block=1, + downsample_each_block=True, + use_timestep_embedding=True, + freq_shift=1.0, + flip_sin_to_cos=False, + time_embedding_type="positional", + act_fn="mish", + ), + (), + { + "sample": np.random.randn(4, 14, 16).astype(np.float32), + "timestep": np.array([10] * 4, dtype=np.int64), + "return_dict": False, + }, + ], +] + + +UNETSTABLECASCADE_CASES = [ + [ + "UNetStableCascadeModel_prior", + "diffusers.models.unets.unet_stable_cascade.StableCascadeUNet", + "mindone.diffusers.models.unets.unet_stable_cascade.StableCascadeUNet", + (), + dict( + block_out_channels=(2048, 2048), + block_types_per_layer=( + 
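+                # prior variant: every level runs SDCascadeResBlock -> SDCascadeTimestepBlock -> SDCascadeAttnBlock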
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ), + clip_image_in_channels=768, + clip_seq=4, + clip_text_in_channels=1280, + clip_text_pooled_in_channels=1280, + conditioning_dim=2048, + down_blocks_repeat_mappers=(1, 1), + down_num_layers_per_block=(8, 24), + dropout=(0.1, 0.1), + effnet_in_channels=None, + in_channels=16, + kernel_size=3, + num_attention_heads=(32, 32), + out_channels=16, + patch_size=1, + pixel_mapper_in_channels=None, + self_attn=True, + switch_level=(False,), + timestep_conditioning_type=("sca", "crp"), + timestep_ratio_embedding_dim=64, + up_blocks_repeat_mappers=(1, 1), + up_num_layers_per_block=(24, 8), + ), + (), + { + "sample": np.random.randn(1, 16, 24, 24).astype(np.float32), + "timestep_ratio": np.array([1], dtype=np.float32), + "clip_text_pooled": np.random.randn(1, 1, 1280).astype(np.float32), + "clip_text": np.random.randn(1, 77, 1280).astype(np.float32), + "clip_img": np.random.randn(1, 1, 768).astype(np.float32), + "pixels": np.random.randn(1, 3, 8, 8).astype(np.float32), + "return_dict": False, + }, + ], + [ + "UNetStableCascadeModel_decoder", + "diffusers.models.unets.unet_stable_cascade.StableCascadeUNet", + "mindone.diffusers.models.unets.unet_stable_cascade.StableCascadeUNet", + (), + dict( + block_out_channels=(320, 640, 1280, 1280), + block_types_per_layer=( + ("SDCascadeResBlock", "SDCascadeTimestepBlock"), + ("SDCascadeResBlock", "SDCascadeTimestepBlock"), + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), + ), + clip_image_in_channels=None, + clip_seq=4, + clip_text_in_channels=None, + clip_text_pooled_in_channels=1280, + conditioning_dim=1280, + down_blocks_repeat_mappers=(1, 1, 1, 1), + down_num_layers_per_block=(2, 6, 28, 6), + dropout=(0, 0, 0.1, 0.1), + effnet_in_channels=16, + in_channels=4, + kernel_size=3, + num_attention_heads=(0, 0, 20, 20), + out_channels=4, + patch_size=2, + pixel_mapper_in_channels=3, + self_attn=True, + switch_level=None, + timestep_conditioning_type=("sca",), + timestep_ratio_embedding_dim=64, + up_blocks_repeat_mappers=(3, 3, 2, 2), + up_num_layers_per_block=(6, 28, 6, 2), + ), + (), + { + "sample": np.random.randn(1, 4, 256, 256).astype(np.float32), + "timestep_ratio": np.array([1], dtype=np.float32), + "clip_text_pooled": np.random.randn(1, 1, 1280).astype(np.float32), + "clip_text": np.random.randn(1, 77, 1280).astype(np.float32), + "pixels": np.random.randn(1, 3, 8, 8).astype(np.float32), + "return_dict": False, + }, + ], +] + + # CONTROL_NET_CASES: outputs format isn't same with others ALL_CASES = ( - NORMALIZATION_CASES + EMBEDDINGS_CASES + UPSAMPLE2D_CASES + DOWNSAMPLE2D_CASES + RESNET_CASES + T2I_ADAPTER_CASES + NORMALIZATION_CASES + + EMBEDDINGS_CASES + + UPSAMPLE2D_CASES + + DOWNSAMPLE2D_CASES + + RESNET_CASES + + T2I_ADAPTER_CASES + + UNET1D_CASES + + UNETSTABLECASCADE_CASES )