Add hpcai OpenSora v1.2 - 3D VAE inference #560

Merged
merged 45 commits on Jul 11, 2024
Changes from 30 commits

Commits (45)
269b8b6
add vae 3d enc-dec
SamitHuang Jun 18, 2024
c4de776
update test
SamitHuang Jun 19, 2024
37f74a8
dev save
SamitHuang Jun 19, 2024
5dff480
testing
SamitHuang Jun 19, 2024
c6122a8
spatial vae test pass
SamitHuang Jun 19, 2024
2edee15
fix
SamitHuang Jun 19, 2024
dfbfdf4
Merge branch 'cai_os1.2' of github.com:SamitHuang/mindone into cai_os1.2
SamitHuang Jun 19, 2024
3faeb25
add vae param list
SamitHuang Jun 19, 2024
3f9ee09
fix name order
SamitHuang Jun 19, 2024
cce6cb5
Merge branch 'cai_os1.2' of github.com:SamitHuang/mindone into cai_os1.2
SamitHuang Jun 19, 2024
50dbf03
add shape
SamitHuang Jun 19, 2024
9e7e582
add shape
SamitHuang Jun 19, 2024
2165403
Merge branch 'cai_os1.2' of github.com:SamitHuang/mindone into cai_os1.2
SamitHuang Jun 19, 2024
7ab7508
order pnames
SamitHuang Jun 19, 2024
bf57f53
ordered temporal pnames
SamitHuang Jun 19, 2024
733fa68
vae 3d recons ok
SamitHuang Jun 19, 2024
3a32e29
update docs
SamitHuang Jun 20, 2024
5eb4c6e
add test scripts
SamitHuang Jun 20, 2024
4241c2d
add convert script
SamitHuang Jun 20, 2024
9d5cc08
adapt to 910b
SamitHuang Jun 20, 2024
543fbf6
support ms2.3 5d GN
SamitHuang Jun 20, 2024
1fb8088
rm test files
SamitHuang Jun 20, 2024
614884f
fix format
SamitHuang Jun 20, 2024
5e5d3a9
debug infer
SamitHuang Jun 20, 2024
74e6969
add sample t2v yaml
SamitHuang Jun 20, 2024
0395b89
fix i2v
SamitHuang Jun 20, 2024
314d65f
Merge branch 'cai_os1.2' of https://github.com/samithuang/mindone int…
SamitHuang Jun 20, 2024
8429baa
update comment
SamitHuang Jun 20, 2024
ec1e5bf
fix format
SamitHuang Jun 20, 2024
420d935
rm tmp test
SamitHuang Jun 20, 2024
5f5ee1b
fix docs
SamitHuang Jun 20, 2024
be55c84
fix var name
SamitHuang Jun 21, 2024
b671293
fix latent shape compute
SamitHuang Jun 21, 2024
2756d94
Merge branch 'pr_vae1.2' of https://github.com/samithuang/mindone int…
SamitHuang Jun 21, 2024
d94ec5e
add info
SamitHuang Jun 21, 2024
6c3972b
fix image enc/dec
SamitHuang Jun 28, 2024
44b5c56
fix format
SamitHuang Jun 28, 2024
cd547ed
adapt new vae in training
SamitHuang Jun 28, 2024
ba17c00
fix dtype
SamitHuang Jun 28, 2024
0a067f9
pad bf16 fixed by cast to fp16
SamitHuang Jul 3, 2024
1ea897e
fix ops.pad bf16 with fp32 cast
SamitHuang Jul 3, 2024
32ed145
Merge branch 'master' into pr_vae1.2
SamitHuang Jul 4, 2024
7bb2697
replace pad with concat
SamitHuang Jul 5, 2024
c42096c
fix conflict
SamitHuang Jul 5, 2024
5507079
replace pad_at_dim with concat for bf16
SamitHuang Jul 11, 2024
@@ -0,0 +1,30 @@
model_version: v1.2
ckpt_path: "models/opensora_v1.2_stage3.ckpt"
t5_model_dir: "models/t5-v1_1-xxl/"

vae_model_type: "OpenSoraVAE_V1_2"
Collaborator
Suggested change
vae_model_type: "OpenSoraVAE_V1_2"
vae_type: OpenSoraVAE_V1_2

vae_checkpoint: "models/OpenSora-VAE-v1.2/model.ckpt"
vae_micro_batch_size: 4
vae_micro_frame_size: 17

image_size: [ 240, 320 ]
num_frames: 24
frame_interval: 3
fps: 24
enable_flash_attention: True
model_max_length: 200
dtype: "fp16"
batch_size: 1

# sampling
sampling_steps: 100
guidance_scale: 7.0
guidance_channels: 3
seed: 42
ddim_sampling: False

loop: 1
condition_frame_length: 4

captions:
- "Snow falling over multiple houses and trees on winter landscape against night sky. christmas festivity and celebration concept"
6 changes: 6 additions & 0 deletions examples/opensora_hpcai/docs/quick_start.md
@@ -60,6 +60,12 @@ Prepare the model checkpoints of T5, VAE, and STDiT and put them under `models/`
Convert to ms checkpoint: `python tools/convert_pt2ms.py --src /path/to/PixArt-XL-2-512x512.pth --target models/PixArt-XL-2-512x512.ckpt`
It will be used for better model initialization.


### OpenSora v1.2

- 3D VAE:
`python tools/convert_vae_3d.py --src path/to/OpenSora-VAE-v1.2/model.safetensors --target models/OpenSora-VAE-v1.2/model.ckpt`
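
To verify the conversion, the resulting file should load as a regular MindSpore checkpoint (a minimal sketch; the expected key names are an assumption based on the model structure in this PR):

```python
import mindspore as ms

# Sanity-check the converted 3D VAE checkpoint.
sd = ms.load_checkpoint("models/OpenSora-VAE-v1.2/model.ckpt")
print(f"{len(sd)} parameters")
print(list(sd.keys())[:5])  # expect spatial_vae.module.* and temporal_vae.* entries
```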

## Inference

### Text-to-Video
@@ -34,7 +34,9 @@ def __init__(
self.exp = ops.Exp()
self.stdnormal = ops.StandardNormal()

def init_from_ckpt(self, path, ignore_keys=list(), remove_prefix=["first_stage_model.", "autoencoder."]):
def init_from_ckpt(
self, path, ignore_keys=list(), remove_prefix=["first_stage_model.", "autoencoder.", "spatial_vae.module."]
):
# TODO: support auto download pretrained checkpoints
sd = ms.load_checkpoint(path)
keys = list(sd.keys())
@@ -55,7 +57,9 @@ def init_from_ckpt(self, path, ignore_keys=list(), remove_prefix=["first_stage_m
is_vae_param = True
if not is_vae_param:
sd.pop(pname)
ms.load_param_into_net(self, sd, strict_load=False)
pu, cu = ms.load_param_into_net(self, sd, strict_load=False)
print(f"Net param not loaded : {pu}")
print(f"Checkpoint param not loaded : {cu}")
print(f"Restored from {path}")

def _encode(self, x):
304 changes: 303 additions & 1 deletion examples/opensora_hpcai/opensora/models/vae/vae.py
@@ -1,10 +1,13 @@
import logging
import os

from transformers import PretrainedConfig

import mindspore as ms
from mindspore import ops
from mindspore import nn, ops

from .autoencoder_kl import AutoencoderKL as AutoencoderKL_SD
from .vae_temporal import VAE_Temporal_SD # noqa: F401

__all__ = ["AutoencoderKL"]

@@ -49,3 +52,302 @@ def encode_with_moments_output(self, x):
std = self.exp(0.5 * logvar)

return mean, std


# -------------------------------- OpenSora v1.2 Begin ------------------------------------ #

SDXL_CONFIG = SD_CONFIG.copy()
SDXL_CONFIG.update({"resolution": 512})


class VideoAutoencoderKL(nn.Cell):
"""
Spatial VAE
"""

def __init__(
self,
config=SDXL_CONFIG,
ckpt_path=None,
micro_batch_size=None,
):
super().__init__()

self.module = AutoencoderKL_SD(
ddconfig=config,
embed_dim=config["z_channels"],
ckpt_path=ckpt_path,
)

self.out_channels = config["z_channels"] # self.module.config.latent_channels
self.patch_size = (1, 8, 8)
self.micro_batch_size = micro_batch_size

# FIXME: the SDXL VAE config sets "scaling_factor": 0.13025, see
# https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/blob/main/vae/config.json.
# Using 0.18215 instead was a mistake made during the training of OpenSora v1.2.
# To re-use the trained weights, we must keep this mistake;
# for new training runs it should be corrected to 0.13025.
self.scale_factor = 0.18215

@staticmethod
def rearrange_in(x):
B, C, T, H, W = x.shape
# (b c t h w) -> (b t c h w)
x = ops.transpose(x, (0, 2, 1, 3, 4))
x = ops.reshape(x, (B * T, C, H, W))

return x

@staticmethod
def rearrange_out(x, B):
# x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
BT, C, H, W = x.shape
T = BT // B
x = ops.reshape(x, (B, T, C, H, W))
x = ops.transpose(x, (0, 2, 1, 3, 4))

return x
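# Illustration (not part of this diff): the two helpers above invert each other.
# import numpy as np
# x = np.random.rand(2, 4, 3, 8, 8)                            # (B, C, T, H, W)
# flat = x.transpose(0, 2, 1, 3, 4).reshape(2 * 3, 4, 8, 8)    # rearrange_in
# back = flat.reshape(2, 3, 4, 8, 8).transpose(0, 2, 1, 3, 4)  # rearrange_out
# assert np.array_equal(back, x)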

def encode(self, x):
# x: (B, C, T, H, W)
# NOTE: remember to apply stop_gradient when invoking this during training
# is_video = (x.ndim == 5)

B = x.shape[0]
# x = rearrange(x, "B C T H W -> (B T) C H W")
x = self.rearrange_in(x)

if self.micro_batch_size is None:
# x = self.module.encode(x).latent_dist.sample().mul_(0.18215)
x = self.module.encode(x) * self.scale_factor
else:
bs = self.micro_batch_size
x_out = []
# FIXME: supported in graph mode? or use split
for i in range(0, x.shape[0], bs):
x_bs = x[i : i + bs]
x_bs = self.module.encode(x_bs) * self.scale_factor
x_out.append(x_bs)
x = ops.cat(x_out, axis=0)

# x = rearrange(x, "(B T) C H W -> B C T H W", B=B)
x = self.rearrange_out(x, B=B)

return x

def decode(self, x, **kwargs):
# is_video = (x.ndim == 5)

B = x.shape[0]
# x: (B, Z, T, H, W)
# x = rearrange(x, "B Z T H W -> (B T) Z H W")
x = self.rearrange_in(x)

if self.micro_batch_size is None:
x = self.module.decode(x / self.scale_factor)
else:
# NOTE: cannot be used for training
bs = self.micro_batch_size
x_out = []
for i in range(0, x.shape[0], bs):
x_bs = x[i : i + bs]
x_bs = self.module.decode(x_bs / self.scale_factor)
x_out.append(x_bs)
x = ops.cat(x_out, axis=0)

# x = rearrange(x, "(B T) Z H W -> B Z T H W", B=B)
x = self.rearrange_out(x, B=B)

return x

def get_latent_size(self, input_size):
latent_size = []
for i in range(3):
# assert (
# input_size[i] is None or input_size[i] % self.patch_size[i] == 0
# ), "Input size must be divisible by patch size"
latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
return latent_size


class VideoAutoencoderPipelineConfig(PretrainedConfig):
Collaborator
I don't think it's good to use PretrainedConfig in our project. It's better to move these parameters under __init__ in VideoAutoencoderPipeline.
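
For illustration, that suggestion might look like this (a hypothetical sketch, not part of this diff):

```python
# Hypothetical: replace PretrainedConfig with plain __init__ arguments.
class VideoAutoencoderPipeline(nn.Cell):
    def __init__(
        self,
        vae_2d: dict = None,
        vae_temporal: dict = None,
        freeze_vae_2d: bool = False,
        cal_loss: bool = False,
        micro_frame_size: int = None,
        shift=0.0,
        scale=1.0,
    ):
        super().__init__()
        self.spatial_vae = build_module_from_config(vae_2d)
        self.temporal_vae = build_module_from_config(vae_temporal)
        # ... rest as in this PR ...
```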

model_type = "VideoAutoencoderPipeline"

def __init__(
self,
vae_2d=None,
vae_temporal=None,
from_pretrained=None,
freeze_vae_2d=False,
cal_loss=False,
micro_frame_size=None,
shift=0.0,
scale=1.0,
**kwargs,
):
self.vae_2d = vae_2d
self.vae_temporal = vae_temporal
self.from_pretrained = from_pretrained
self.freeze_vae_2d = freeze_vae_2d
self.cal_loss = cal_loss
self.micro_frame_size = micro_frame_size
self.shift = shift
self.scale = scale
super().__init__(**kwargs)


def build_module_from_config(config):
"""
config dict format:
- type: model class name
- others: model init args
"""
cfg = config.copy()
name = cfg.pop("type")
kwargs = cfg

# FIXME: use importlib with path
module = eval(name)(**kwargs)
return module
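# Example usage (illustration; mirrors the OpenSoraVAE_V1_2 factory below):
# vae_2d_cfg = dict(type="VideoAutoencoderKL", config=SDXL_CONFIG, micro_batch_size=4)
# spatial_vae = build_module_from_config(vae_2d_cfg)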


class VideoAutoencoderPipeline(nn.Cell):
"""
Main model: spatial VAE + temporal VAE
"""

# config_class = VideoAutoencoderPipelineConfig
def __init__(self, config: VideoAutoencoderPipelineConfig):
super().__init__()
self.spatial_vae = build_module_from_config(config.vae_2d)
self.temporal_vae = build_module_from_config(config.vae_temporal)

self.cal_loss = config.cal_loss
self.micro_frame_size = config.micro_frame_size
self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0]

if config.freeze_vae_2d:
for param in self.spatial_vae.get_parameters():
param.requires_grad = False

self.out_channels = self.temporal_vae.out_channels

# normalization parameters
scale = ms.Tensor(config.scale)
shift = ms.Tensor(config.shift)
if len(scale.shape) > 0:
scale = scale[None, :, None, None, None]
if len(shift.shape) > 0:
shift = shift[None, :, None, None, None]
self.scale = ms.Parameter(scale, requires_grad=False)
self.shift = ms.Parameter(shift, requires_grad=False)

def encode(self, x):
x_z = self.spatial_vae.encode(x)

if self.micro_frame_size is None:
posterior_mean, posterior_logvar = self.temporal_vae._encode(x_z)
z = self.temporal_vae.sample(posterior_mean, posterior_logvar)
else:
z_list = []
# TODO: the torch impl has a bug: the posterior from each micro-frame slice should be concatenated as well. We keep the behavior unchanged to save the memory a concatenated posterior would need, so this path is inference-only (see the cal_loss check below).
for i in range(0, x_z.shape[2], self.micro_frame_size):
x_z_bs = x_z[:, :, i : i + self.micro_frame_size]
posterior_mean, posterior_logvar = self.temporal_vae._encode(x_z_bs)
z_bs = self.temporal_vae.sample(posterior_mean, posterior_logvar)
z_list.append(z_bs)
z = ops.cat(z_list, axis=2)
if self.cal_loss:
raise ValueError(
"Please fix the bug of posterior concatenation for temporal vae training with micro_frame_size"
)

if self.cal_loss:
return z, posterior_mean, posterior_logvar, x_z
else:
return (z - self.shift) / self.scale

def decode(self, z, num_frames=None):
if not self.cal_loss:
z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype)

if self.micro_frame_size is None:
x_z = self.temporal_vae.decode(z, num_frames=num_frames)
x = self.spatial_vae.decode(x_z)
else:
x_z_list = []
for i in range(0, z.shape[2], self.micro_z_frame_size):
z_bs = z[:, :, i : i + self.micro_z_frame_size]
x_z_bs = self.temporal_vae.decode(z_bs, num_frames=min(self.micro_frame_size, num_frames))
x_z_list.append(x_z_bs)
num_frames -= self.micro_frame_size
x_z = ops.cat(x_z_list, axis=2)
x = self.spatial_vae.decode(x_z)

if self.cal_loss:
return x, x_z
else:
return x

def construct(self, x):
assert self.cal_loss, "This method is only available when cal_loss is True"
z, posterior_mean, posterior_logvar, x_z = self.encode(x)
x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2])
return x_rec, x_z_rec, z, posterior_mean, posterior_logvar, x_z

def get_latent_size(self, input_size):
if self.micro_frame_size is None or input_size[0] is None:
return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
else:
sub_input_size = [self.micro_frame_size, input_size[1], input_size[2]]
sub_latent_size = self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(sub_input_size))
sub_latent_size[0] = sub_latent_size[0] * (input_size[0] // self.micro_frame_size)
remain_temporal_size = [input_size[0] % self.micro_frame_size, None, None]
if remain_temporal_size[0] > 0:
remain_size = self.temporal_vae.get_latent_size(remain_temporal_size)
sub_latent_size[0] += remain_size[0]
return sub_latent_size
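# Worked example of the micro-frame arithmetic above (illustration only; assumes
# the temporal VAE maps T frames to ceil(T / 4) latent frames and the spatial
# VAE downsamples H and W by 8):
#   micro_frame_size = 17 -> micro_z_frame_size = 5
#   input (51, 240, 320): 51 // 17 = 3 full slices -> 3 * 5 = 15, no remainder
#     -> latent size [15, 30, 40]
#   input (24, 240, 320): one slice of 17 -> 5, remainder 7 -> ceil(7 / 4) = 2
#     -> latent size [7, 30, 40]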

def get_temporal_last_layer(self):
return self.temporal_vae.decoder.conv_out.conv.weight


def OpenSoraVAE_V1_2(
micro_batch_size=4,
micro_frame_size=17,
ckpt_path=None,
freeze_vae_2d=False,
cal_loss=False,
):
vae_2d = dict(
type="VideoAutoencoderKL",
config=SDXL_CONFIG,
micro_batch_size=micro_batch_size,
)
vae_temporal = dict(
type="VAE_Temporal_SD",
from_pretrained=None,
)
shift = (-0.10, 0.34, 0.27, 0.98)
scale = (3.85, 2.32, 2.33, 3.06)
kwargs = dict(
vae_2d=vae_2d,
vae_temporal=vae_temporal,
freeze_vae_2d=freeze_vae_2d,
cal_loss=cal_loss,
micro_frame_size=micro_frame_size,
shift=shift,
scale=scale,
)

config = VideoAutoencoderPipelineConfig(**kwargs)
model = VideoAutoencoderPipeline(config)

if ckpt_path is not None:
sd = ms.load_checkpoint(ckpt_path)
pu, cu = ms.load_param_into_net(model, sd, strict_load=False)
print(f"Net param not loaded : {pu}")
print(f"Checkpoint param not loaded : {cu}")

return model
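
A minimal end-to-end usage sketch of this factory (shapes are assumptions that follow from the patch sizes above: T 17 -> 5, H and W divided by 8, 4 latent channels):

```python
import numpy as np
import mindspore as ms

vae = OpenSoraVAE_V1_2(ckpt_path="models/OpenSora-VAE-v1.2/model.ckpt")

x = ms.Tensor(np.random.rand(1, 3, 17, 256, 256), ms.float32)  # (B, C, T, H, W)
z = vae.encode(x)                      # expected shape (1, 4, 5, 32, 32)
x_rec = vae.decode(z, num_frames=17)   # expected shape (1, 3, 17, 256, 256)
print(z.shape, x_rec.shape)
```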