Add hpcai OpenSora v1.2 - 3D VAE inference #560
Merged · 45 commits, all by SamitHuang (changes shown from 30 commits).
269b8b6  add vae 3d enc-dec
c4de776  update test
37f74a8  dev save
5dff480  testing
c6122a8  spatial vae test pass
2edee15  fix
dfbfdf4  Merge branch 'cai_os1.2' of github.com:SamitHuang/mindone into cai_os1.2
3faeb25  add vae param list
3f9ee09  fix name order
cce6cb5  Merge branch 'cai_os1.2' of github.com:SamitHuang/mindone into cai_os1.2
50dbf03  add shape
9e7e582  add shape
2165403  Merge branch 'cai_os1.2' of github.com:SamitHuang/mindone into cai_os1.2
7ab7508  order pnames
bf57f53  ordered temporal pnames
733fa68  vae 3d recons ok
3a32e29  update docs
5eb4c6e  add test scripts
4241c2d  add convert script
9d5cc08  adapt to 910b
543fbf6  support ms2.3 5d GN
1fb8088  rm test files
614884f  fix format
5e5d3a9  debug infer
74e6969  add sample t2v yaml
0395b89  fix i2v
314d65f  Merge branch 'cai_os1.2' of https://github.com/samithuang/mindone int…
8429baa  update comment
ec1e5bf  fix format
420d935  rm tmp test
5f5ee1b  fix docs
be55c84  fix var name
b671293  fix latent shape compute
2756d94  Merge branch 'pr_vae1.2' of https://github.com/samithuang/mindone int…
d94ec5e  add info
6c3972b  fix image enc/dec
44b5c56  fix format
cd547ed  adapt new vae in training
ba17c00  fix dtype
0a067f9  pad bf16 fixed by cast to fp16
1ea897e  fix ops.pad bf16 with fp32 cast
32ed145  Merge branch 'master' into pr_vae1.2
7bb2697  replace pad with concat
c42096c  fix conflict
5507079  replace pad_at_dim with concat for bf16
examples/opensora_hpcai/configs/opensora-v1-2/inference/sample_t2v.yaml (+30 −0)

@@ -0,0 +1,30 @@

```yaml
model_version: v1.2
ckpt_path: "models/opensora_v1.2_stage3.ckpt"
t5_model_dir: "models/t5-v1_1-xxl/"

vae_model_type: "OpenSoraVAE_V1_2"
vae_checkpoint: "models/OpenSora-VAE-v1.2/model.ckpt"
vae_micro_batch_size: 4
vae_micro_frame_size: 17

image_size: [ 240, 320 ]
num_frames: 24
frame_interval: 3
fps: 24
enable_flash_attention: True
model_max_length: 200
dtype: "fp16"
batch_size: 1

# sampling
sampling_steps: 100
guidance_scale: 7.0
guidance_channels: 3
seed: 42
ddim_sampling: False

loop: 1
condition_frame_length: 4

captions:
  - "Snow falling over multiple houses and trees on winter landscape against night sky. christmas festivity and celebration concept"
```
@@ -1,10 +1,13 @@

```python
import logging
import os

from transformers import PretrainedConfig

import mindspore as ms
from mindspore import nn, ops

from .autoencoder_kl import AutoencoderKL as AutoencoderKL_SD
from .vae_temporal import VAE_Temporal_SD  # noqa: F401

__all__ = ["AutoencoderKL"]
```
@@ -49,3 +52,302 @@ def encode_with_moments_output(self, x):

```python
        std = self.exp(0.5 * logvar)

        return mean, std


# -------------------------------- OpenSora v1.2 Begin ------------------------------------ #

SDXL_CONFIG = SD_CONFIG.copy()
SDXL_CONFIG.update({"resolution": 512})


class VideoAutoencoderKL(nn.Cell):
    """
    Spatial VAE: applies the 2D SD(XL) autoencoder frame by frame.
    """

    def __init__(
        self,
        config=SDXL_CONFIG,
        ckpt_path=None,
        micro_batch_size=None,
    ):
        super().__init__()

        self.module = AutoencoderKL_SD(
            ddconfig=config,
            embed_dim=config["z_channels"],
            ckpt_path=ckpt_path,
        )

        self.out_channels = config["z_channels"]  # self.module.config.latent_channels
        self.patch_size = (1, 8, 8)
        self.micro_batch_size = micro_batch_size

        # FIXME: "scaling_factor": 0.13025 is set in
        # https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers/blob/main/vae/config.json.
        # Using 0.18215 here reproduces a mistake made during the training of OpenSora v1.2;
        # we keep it so the trained model can be reused. For training, refine it to 0.13025.
        self.scale_factor = 0.18215

    @staticmethod
    def rearrange_in(x):
        # (B, C, T, H, W) -> (B*T, C, H, W)
        B, C, T, H, W = x.shape
        x = ops.transpose(x, (0, 2, 1, 3, 4))
        x = ops.reshape(x, (B * T, C, H, W))
        return x

    @staticmethod
    def rearrange_out(x, B):
        # (B*T, C, H, W) -> (B, C, T, H, W), i.e. rearrange(x, "(B T) C H W -> B C T H W", B=B)
        BT, C, H, W = x.shape
        T = BT // B
        x = ops.reshape(x, (B, T, C, H, W))
        x = ops.transpose(x, (0, 2, 1, 3, 4))
        return x
```
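`rearrange_in` / `rearrange_out` re-implement the einops `rearrange` patterns noted in the comments using only `transpose` + `reshape`. A quick shape round-trip on a dummy tensor (a sketch, not part of the PR):

```python
import numpy as np
import mindspore as ms

x = ms.Tensor(np.zeros((2, 3, 4, 8, 8)), ms.float32)  # (B, C, T, H, W)
flat = VideoAutoencoderKL.rearrange_in(x)             # fold T into the batch axis
print(flat.shape)                                     # (8, 3, 8, 8)
back = VideoAutoencoderKL.rearrange_out(flat, B=2)    # unfold it again
print(back.shape)                                     # (2, 3, 4, 8, 8)
```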
```python
    def encode(self, x):
        # x: (B, C, T, H, W)
        # NOTE: remember to apply stop_gradient when invoking this for inference only
        B = x.shape[0]
        x = self.rearrange_in(x)  # (B, C, T, H, W) -> (B*T, C, H, W)

        if self.micro_batch_size is None:
            x = self.module.encode(x) * self.scale_factor
        else:
            bs = self.micro_batch_size
            x_out = []
            # FIXME: is this loop supported in graph mode? or use split
            for i in range(0, x.shape[0], bs):
                x_bs = x[i : i + bs]
                x_bs = self.module.encode(x_bs) * self.scale_factor
                x_out.append(x_bs)
            x = ops.cat(x_out, axis=0)

        x = self.rearrange_out(x, B=B)  # (B*T, C, H, W) -> (B, C, T, H, W)
        return x

    def decode(self, x, **kwargs):
        # x: (B, Z, T, H, W)
        B = x.shape[0]
        x = self.rearrange_in(x)  # (B, Z, T, H, W) -> (B*T, Z, H, W)

        if self.micro_batch_size is None:
            x = self.module.decode(x / self.scale_factor)
        else:
            # NOTE: cannot be used for training
            bs = self.micro_batch_size
            x_out = []
            for i in range(0, x.shape[0], bs):
                x_bs = x[i : i + bs]
                x_bs = self.module.decode(x_bs / self.scale_factor)
                x_out.append(x_bs)
            x = ops.cat(x_out, axis=0)

        x = self.rearrange_out(x, B=B)  # (B*T, Z, H, W) -> (B, Z, T, H, W)
        return x
```
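The `micro_batch_size` branches in `encode` and `decode` share one pattern: run the 2D VAE on fixed-size slices of the flattened `B*T` axis and concatenate, so peak activation memory scales with the slice size rather than with `B*T`. The pattern in isolation (a sketch; `f` stands in for `self.module.encode` or `self.module.decode`, and `ops` is `mindspore.ops`):

```python
def micro_batched(f, x, bs):
    # Process bs-sized slices sequentially, then stitch them back together.
    out = []
    for i in range(0, x.shape[0], bs):
        out.append(f(x[i : i + bs]))
    return ops.cat(out, axis=0)
```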
```python
    def get_latent_size(self, input_size):
        # input_size: (T, H, W); None entries pass through unchanged
        latent_size = []
        for i in range(3):
            # assert input_size[i] is None or input_size[i] % self.patch_size[i] == 0, \
            #     "Input size must be divisible by patch size"
            latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
        return latent_size
```
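With the `sample_t2v.yaml` settings above and `patch_size = (1, 8, 8)`, the spatial latent is 8× smaller in height and width and unchanged in time:

```python
patch_size = (1, 8, 8)
input_size = (24, 240, 320)  # (T, H, W) from sample_t2v.yaml
latent = [s // p for s, p in zip(input_size, patch_size)]
print(latent)                # [24, 30, 40]
```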
```python
class VideoAutoencoderPipelineConfig(PretrainedConfig):
    model_type = "VideoAutoencoderPipeline"

    def __init__(
        self,
        vae_2d=None,
        vae_temporal=None,
        from_pretrained=None,
        freeze_vae_2d=False,
        cal_loss=False,
        micro_frame_size=None,
        shift=0.0,
        scale=1.0,
        **kwargs,
    ):
        self.vae_2d = vae_2d
        self.vae_temporal = vae_temporal
        self.from_pretrained = from_pretrained
        self.freeze_vae_2d = freeze_vae_2d
        self.cal_loss = cal_loss
        self.micro_frame_size = micro_frame_size
        self.shift = shift
        self.scale = scale
        super().__init__(**kwargs)
```

A review comment was left on this class, anchored on the `PretrainedConfig` base: "I don't think it's good to use …"

```python
def build_module_from_config(config):
    """
    Instantiate a module from a config dict.

    config dict format:
      - type: model class name
      - others: model init args
    """
    cfg = config.copy()
    name = cfg.pop("type")
    kwargs = cfg

    # FIXME: use importlib with path
    module = eval(name)(**kwargs)
    return module
```
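One way to resolve the `FIXME` above without `eval()` is to look the class up by name, using `importlib` when a dotted path is given. This is a sketch under that assumption, not part of the PR:

```python
import importlib

def build_module_from_config_v2(config):
    cfg = config.copy()
    name = cfg.pop("type")
    if "." in name:
        # dotted path, e.g. "opensora.models.vae.vae.VideoAutoencoderKL" (hypothetical)
        module_path, cls_name = name.rsplit(".", 1)
        cls = getattr(importlib.import_module(module_path), cls_name)
    else:
        # bare name defined in this module, e.g. "VideoAutoencoderKL"
        cls = globals()[name]
    return cls(**cfg)
```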
```python
class VideoAutoencoderPipeline(nn.Cell):
    """
    Main model: spatial VAE followed by temporal VAE.
    """

    # config_class = VideoAutoencoderPipelineConfig
    def __init__(self, config: VideoAutoencoderPipelineConfig):
        super().__init__()
        self.spatial_vae = build_module_from_config(config.vae_2d)
        self.temporal_vae = build_module_from_config(config.vae_temporal)

        self.cal_loss = config.cal_loss
        self.micro_frame_size = config.micro_frame_size
        self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0]

        if config.freeze_vae_2d:
            for param in self.spatial_vae.get_parameters():
                param.requires_grad = False

        self.out_channels = self.temporal_vae.out_channels

        # normalization parameters (per-channel latent statistics)
        scale = ms.Tensor(config.scale)
        shift = ms.Tensor(config.shift)
        if len(scale.shape) > 0:
            scale = scale[None, :, None, None, None]
        if len(shift.shape) > 0:
            shift = shift[None, :, None, None, None]
        self.scale = ms.Parameter(scale, requires_grad=False)
        self.shift = ms.Parameter(shift, requires_grad=False)
```
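The `[None, :, None, None, None]` indexing turns the per-channel `shift`/`scale` tuples (length 4, one entry per latent channel; see `OpenSoraVAE_V1_2` below) into `(1, C, 1, 1, 1)` tensors that broadcast against `(B, C, T, H, W)` latents:

```python
scale = ms.Tensor((3.85, 2.32, 2.33, 3.06))  # per-channel values from the factory below
print(scale.shape)                           # (4,)
scale = scale[None, :, None, None, None]
print(scale.shape)                           # (1, 4, 1, 1, 1) — broadcasts over B, T, H, W
```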
```python
    def encode(self, x):
        x_z = self.spatial_vae.encode(x)

        if self.micro_frame_size is None:
            posterior_mean, posterior_logvar = self.temporal_vae._encode(x_z)
            z = self.temporal_vae.sample(posterior_mean, posterior_logvar)
        else:
            z_list = []
            # TODO: the torch implementation has a bug here: the posteriors of all
            # micro-frame slices should be concatenated as well. We keep the behavior
            # unchanged (only the last slice's posterior survives) to save the memory
            # a concatenated posterior would cost.
            for i in range(0, x_z.shape[2], self.micro_frame_size):
                x_z_bs = x_z[:, :, i : i + self.micro_frame_size]
                posterior_mean, posterior_logvar = self.temporal_vae._encode(x_z_bs)
                z_bs = self.temporal_vae.sample(posterior_mean, posterior_logvar)
                z_list.append(z_bs)
            z = ops.cat(z_list, axis=2)
            if self.cal_loss:
                raise ValueError(
                    "Please fix the bug of posterior concatenation for temporal vae training with micro_frame_size"
                )

        if self.cal_loss:
            return z, posterior_mean, posterior_logvar, x_z
        else:
            return (z - self.shift) / self.scale

    def decode(self, z, num_frames=None):
        if not self.cal_loss:
            z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype)

        if self.micro_frame_size is None:
            x_z = self.temporal_vae.decode(z, num_frames=num_frames)
            x = self.spatial_vae.decode(x_z)
        else:
            x_z_list = []
            for i in range(0, z.shape[2], self.micro_z_frame_size):
                z_bs = z[:, :, i : i + self.micro_z_frame_size]
                x_z_bs = self.temporal_vae.decode(z_bs, num_frames=min(self.micro_frame_size, num_frames))
                x_z_list.append(x_z_bs)
                num_frames -= self.micro_frame_size
            x_z = ops.cat(x_z_list, axis=2)
            x = self.spatial_vae.decode(x_z)

        if self.cal_loss:
            return x, x_z
        else:
            return x

    def construct(self, x):
        assert self.cal_loss, "This method is only available when cal_loss is True"
        z, posterior_mean, posterior_logvar, x_z = self.encode(x)
        x_rec, x_z_rec = self.decode(z, num_frames=x_z.shape[2])
        return x_rec, x_z_rec, z, posterior_mean, posterior_logvar, x_z

    def get_latent_size(self, input_size):
        if self.micro_frame_size is None or input_size[0] is None:
            return self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(input_size))
        else:
            sub_input_size = [self.micro_frame_size, input_size[1], input_size[2]]
            sub_latent_size = self.temporal_vae.get_latent_size(self.spatial_vae.get_latent_size(sub_input_size))
            sub_latent_size[0] = sub_latent_size[0] * (input_size[0] // self.micro_frame_size)
            remain_temporal_size = [input_size[0] % self.micro_frame_size, None, None]
            if remain_temporal_size[0] > 0:
                remain_size = self.temporal_vae.get_latent_size(remain_temporal_size)
                sub_latent_size[0] += remain_size[0]
            return sub_latent_size
```
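A worked example of this latent-size bookkeeping, assuming the temporal VAE compresses a 17-frame window to 5 latent frames (roughly 4× temporal compression; that 5 is an assumption here, not stated in the diff):

```python
micro_frame_size, micro_z_frame_size = 17, 5  # 5 is assumed, see lead-in
num_frames = 51                               # an example clip length
full_windows = num_frames // micro_frame_size # 3 windows of 17 frames
latent_t = full_windows * micro_z_frame_size  # 3 * 5 = 15 latent frames
# a remainder (num_frames % micro_frame_size > 0) would contribute one more,
# shorter window, exactly as get_latent_size handles above
print(latent_t)  # 15
```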
```python
    def get_temporal_last_layer(self):
        return self.temporal_vae.decoder.conv_out.conv.weight


def OpenSoraVAE_V1_2(
    micro_batch_size=4,
    micro_frame_size=17,
    ckpt_path=None,
    freeze_vae_2d=False,
    cal_loss=False,
):
    vae_2d = dict(
        type="VideoAutoencoderKL",
        config=SDXL_CONFIG,
        micro_batch_size=micro_batch_size,
    )
    vae_temporal = dict(
        type="VAE_Temporal_SD",
        from_pretrained=None,
    )
    shift = (-0.10, 0.34, 0.27, 0.98)
    scale = (3.85, 2.32, 2.33, 3.06)
    kwargs = dict(
        vae_2d=vae_2d,
        vae_temporal=vae_temporal,
        freeze_vae_2d=freeze_vae_2d,
        cal_loss=cal_loss,
        micro_frame_size=micro_frame_size,
        shift=shift,
        scale=scale,
    )

    config = VideoAutoencoderPipelineConfig(**kwargs)
    model = VideoAutoencoderPipeline(config)

    if ckpt_path is not None:
        sd = ms.load_checkpoint(ckpt_path)
        pu, cu = ms.load_param_into_net(model, sd, strict_load=False)
        print(f"Net params not loaded: {pu}")
        print(f"Checkpoint params not loaded: {cu}")

    return model
```
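End to end, the factory yields a ready-to-use inference VAE. A minimal sketch (the checkpoint path and shapes follow `sample_t2v.yaml`'s 240×320, 24-frame setting; the latent time length `T'` depends on the temporal compression):

```python
import numpy as np
import mindspore as ms

vae = OpenSoraVAE_V1_2(ckpt_path="models/OpenSora-VAE-v1.2/model.ckpt")

video = ms.Tensor(np.random.randn(1, 3, 24, 240, 320), ms.float32)  # (B, C, T, H, W)
z = vae.encode(video)                 # normalized latent, (1, 4, T', 30, 40)
recon = vae.decode(z, num_frames=24)  # back to (1, 3, 24, 240, 320)
```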