Skip to content

Commit

Permalink
feat(diffusers/models/unets): add unet1d & StableCascade (mindspore-l…
Browse files Browse the repository at this point in the history
  • Loading branch information
Cui-yshoho authored Jun 19, 2024
1 parent 7e2290e commit 6271bc5
Show file tree
Hide file tree
Showing 10 changed files with 1,821 additions and 8 deletions.
13 changes: 12 additions & 1 deletion mindone/diffusers/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,20 @@ Most base, utility and mixin class are available.
- [ ] StableDiffusionPipeline

### Model

#### AutoEncoders

- [x] AutoencoderKL
- [x] Transformer2DModel

#### UNets

- [x] UNet1DModel
- [x] UNet2DConditionModel
- [x] StableCascadeUNet

#### Transformers

- [x] Transformer2DModel

### Scheduler
- [x] DDIMScheduler/DDPMScheduler/...(30)
Expand Down
12 changes: 11 additions & 1 deletion mindone/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
"AutoencoderKL",
"ModelMixin",
"SD3Transformer2DModel",
"UNet1DModel",
"UNet2DConditionModel",
"UNet2DModel",
"StableCascadeUNet",
],
"optimization": [
"get_constant_schedule",
Expand Down Expand Up @@ -77,7 +79,15 @@

if TYPE_CHECKING:
from .configuration_utils import ConfigMixin
from .models import AutoencoderKL, ModelMixin, SD3Transformer2DModel, UNet2DConditionModel, UNet2DModel
from .models import (
AutoencoderKL,
ModelMixin,
SD3Transformer2DModel,
StableCascadeUNet,
UNet1DModel,
UNet2DConditionModel,
UNet2DModel,
)
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
Expand Down
4 changes: 3 additions & 1 deletion mindone/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
"modeling_utils": ["ModelMixin"],
"transformers.transformer_2d": ["Transformer2DModel"],
"transformers.transformer_sd3": ["SD3Transformer2DModel"],
"unets.unet_1d": ["UNet1DModel"],
"unets.unet_2d": ["UNet2DModel"],
"unets.unet_2d_condition": ["UNet2DConditionModel"],
"unets.unet_stable_cascade": ["StableCascadeUNet"],
}

if TYPE_CHECKING:
Expand All @@ -35,7 +37,7 @@
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .transformers import SD3Transformer2DModel, Transformer2DModel
from .unets import UNet2DConditionModel, UNet2DModel
from .unets import StableCascadeUNet, UNet1DModel, UNet2DConditionModel, UNet2DModel

else:
import sys
Expand Down
6 changes: 4 additions & 2 deletions mindone/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@
def _get_pt2ms_mappings(m):
mappings = {} # pt_param_name: (ms_param_name, pt_param_to_ms_param_func)
for name, cell in m.cells_and_names():
if isinstance(cell, nn.Conv1d):
mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ops.expand_dims(x, axis=-2)
if isinstance(cell, (nn.Conv1d, nn.Conv1dTranspose)):
mappings[f"{name}.weight"] = f"{name}.weight", lambda x: ms.Parameter(
ops.expand_dims(x, axis=-2), name=x.name
)
elif isinstance(cell, nn.Embedding):
mappings[f"{name}.weight"] = f"{name}.embedding_table", lambda x: x
elif isinstance(cell, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
Expand Down
2 changes: 2 additions & 0 deletions mindone/diffusers/models/unets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_stable_cascade import StableCascadeUNet
261 changes: 261 additions & 0 deletions mindone/diffusers/models/unets/unet_1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import mindspore as ms
from mindspore import nn, ops

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block


@dataclass
class UNet1DOutput(BaseOutput):
"""
The output of [`UNet1DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
The hidden states output from the last layer of the model.
"""

sample: ms.Tensor


class UNet1DModel(ModelMixin, ConfigMixin):
r"""
A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime.
in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 2): Number of channels in the output.
extra_in_channels (`int`, *optional*, defaults to 0):
Number of additional channels to be added to the input of the first down block. Useful for cases where the
input data has more channels than what the model was initially designed for.
time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use.
freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip sin to cos for Fourier time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`):
Tuple of block output channels.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet.
out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet.
act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks.
norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization.
layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block.
downsample_each_block (`int`, *optional*, defaults to `False`):
Experimental feature for using a UNet without upsampling.
"""

@register_to_config
def __init__(
self,
sample_size: int = 65536,
sample_rate: Optional[int] = None,
in_channels: int = 2,
out_channels: int = 2,
extra_in_channels: int = 0,
time_embedding_type: str = "fourier",
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
freq_shift: float = 0.0,
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
mid_block_type: Tuple[str] = "UNetMidBlock1D",
out_block_type: str = None,
block_out_channels: Tuple[int] = (32, 32, 64),
act_fn: str = None,
norm_num_groups: int = 8,
layers_per_block: int = 1,
downsample_each_block: bool = False,
):
super().__init__()
self.sample_size = sample_size

# time
if time_embedding_type == "fourier":
self.time_proj = GaussianFourierProjection(
embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = 2 * block_out_channels[0]
elif time_embedding_type == "positional":
self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
)
timestep_input_dim = block_out_channels[0]

if use_timestep_embedding:
time_embed_dim = block_out_channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=timestep_input_dim,
time_embed_dim=time_embed_dim,
act_fn=act_fn,
out_dim=block_out_channels[0],
)

# down
down_blocks = []
output_channel = in_channels
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]

if i == 0:
input_channel += extra_in_channels

is_final_block = i == len(block_out_channels) - 1

down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=block_out_channels[0],
add_downsample=not is_final_block or downsample_each_block,
)
down_blocks.append(down_block)
self.down_blocks = nn.CellList(down_blocks)

# mid
self.mid_block = get_mid_block(
mid_block_type,
in_channels=block_out_channels[-1],
mid_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
embed_dim=block_out_channels[0],
num_layers=layers_per_block,
add_downsample=downsample_each_block,
)

# up
up_blocks = []
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
if out_block_type is None:
final_upsample_channels = out_channels
else:
final_upsample_channels = block_out_channels[0]

for i, up_block_type in enumerate(up_block_types):
prev_output_channel = output_channel
output_channel = (
reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels
)

is_final_block = i == len(block_out_channels) - 1

up_block = get_up_block(
up_block_type,
num_layers=layers_per_block,
in_channels=prev_output_channel,
out_channels=output_channel,
temb_channels=block_out_channels[0],
add_upsample=not is_final_block,
)
up_blocks.append(up_block)
prev_output_channel = output_channel
self.up_blocks = nn.CellList(up_blocks)

# out
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
self.out_block = get_out_block(
out_block_type=out_block_type,
num_groups_out=num_groups_out,
embed_dim=block_out_channels[0],
out_channels=out_channels,
act_fn=act_fn,
fc_dim=block_out_channels[-1] // 4,
)

self.use_timestep_embedding = self.config.use_timestep_embedding

def construct(
self,
sample: ms.Tensor,
timestep: Union[ms.Tensor, float, int],
return_dict: bool = False,
) -> Union[UNet1DOutput, Tuple]:
r"""
The [`UNet1DModel`] forward method.
Args:
sample (`ms.Tensor`):
The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
timestep (`ms.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
return_dict (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple.
Returns:
[`~models.unet_1d.UNet1DOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is
returned where the first element is the sample tensor.
"""

# 1. time
timesteps = timestep
if not ops.is_tensor(timesteps):
timesteps = ms.tensor([timesteps], dtype=ms.int64)
elif ops.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None]

timestep_embed = self.time_proj(timesteps)

# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
timestep_embed = timestep_embed.to(dtype=self.dtype)
if self.use_timestep_embedding:
timestep_embed = self.time_mlp(timestep_embed)
else:
timestep_embed = timestep_embed[..., None]
timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
timestep_embed = timestep_embed.broadcast_to((sample.shape[:1] + timestep_embed.shape[1:]))

# 2. down
down_block_res_samples = ()
for downsample_block in self.down_blocks:
sample, res_samples = downsample_block(hidden_states=sample, temb=timestep_embed)
down_block_res_samples += res_samples

# 3. mid
if self.mid_block:
sample = self.mid_block(sample, timestep_embed)

# 4. up
for i, upsample_block in enumerate(self.up_blocks):
res_samples = down_block_res_samples[-1:]
down_block_res_samples = down_block_res_samples[:-1]
sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed)

# 5. post-process
if self.out_block:
sample = self.out_block(sample, timestep_embed)

if not return_dict:
return (sample,)

return UNet1DOutput(sample=sample)
Loading

0 comments on commit 6271bc5

Please sign in to comment.