diff --git a/demo/animate.py b/demo/animate.py index b71f1940..a4a13a7a 100644 --- a/demo/animate.py +++ b/demo/animate.py @@ -29,6 +29,7 @@ from magicanimate.models.controlnet import ControlNetModel from magicanimate.models.appearance_encoder import AppearanceEncoderModel from magicanimate.models.mutual_self_attention import ReferenceAttentionControl +from magicanimate.models.model_util import load_models from magicanimate.pipelines.pipeline_animation import AnimationPipeline from magicanimate.utils.util import save_videos_grid from accelerate.utils import set_seed @@ -42,57 +43,104 @@ import math from pathlib import Path -class MagicAnimate(): + +class MagicAnimate: def __init__(self, config="configs/prompts/animation.yaml") -> None: print("Initializing MagicAnimate Pipeline...") *_, func_args = inspect.getargvalues(inspect.currentframe()) func_args = dict(func_args) - - config = OmegaConf.load(config) - + + config = OmegaConf.load(config) + inference_config = OmegaConf.load(config.inference_config) - + motion_module = config.motion_module - + ### >>> create animation pipeline >>> ### - tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer") - text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder") + tokenizer, text_encoder, unet, noise_scheduler, vae = load_models( + config.pretrained_model_path, + scheduler_name="", + v2=False, + v_pred=False, + ) + # tokenizer = CLIPTokenizer.from_pretrained( + # config.pretrained_model_path, subfolder="tokenizer" + # ) + # text_encoder = CLIPTextModel.from_pretrained( + # config.pretrained_model_path, subfolder="text_encoder" + # ) if config.pretrained_unet_path: - unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) + unet = UNet3DConditionModel.from_pretrained_2d( + config.pretrained_unet_path, + unet_additional_kwargs=OmegaConf.to_container( + inference_config.unet_additional_kwargs + ), + ) else: - unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) - self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").cuda() - self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks) - self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks) + unet = UNet3DConditionModel.from_pretrained_2d( + unet, + subfolder=None, + unet_additional_kwargs=OmegaConf.to_container( + inference_config.unet_additional_kwargs + ), + ) + self.appearance_encoder = AppearanceEncoderModel.from_pretrained( + config.pretrained_appearance_encoder_path, subfolder="appearance_encoder" + ).cuda() + self.reference_control_writer = ReferenceAttentionControl( + self.appearance_encoder, + do_classifier_free_guidance=True, + mode="write", + fusion_blocks=config.fusion_blocks, + ) + self.reference_control_reader = ReferenceAttentionControl( + unet, + do_classifier_free_guidance=True, + mode="read", + fusion_blocks=config.fusion_blocks, + ) if config.pretrained_vae_path is not None: vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path) - else: - vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae") + # else: + # vae = AutoencoderKL.from_pretrained( + # config.pretrained_model_path, subfolder="vae" + # ) ### Load controlnet - controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path) + controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path) vae.to(torch.float16) unet.to(torch.float16) text_encoder.to(torch.float16) controlnet.to(torch.float16) self.appearance_encoder.to(torch.float16) - + unet.enable_xformers_memory_efficient_attention() self.appearance_encoder.enable_xformers_memory_efficient_attention() controlnet.enable_xformers_memory_efficient_attention() self.pipeline = AnimationPipeline( - vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, - scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=DDIMScheduler( + **OmegaConf.to_container(inference_config.noise_scheduler_kwargs) + ), # NOTE: UniPCMultistepScheduler ).to("cuda") # 1. unet ckpt # 1.1 motion module motion_module_state_dict = torch.load(motion_module, map_location="cpu") - if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]}) - motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict + if "global_step" in motion_module_state_dict: + func_args.update({"global_step": motion_module_state_dict["global_step"]}) + motion_module_state_dict = ( + motion_module_state_dict["state_dict"] + if "state_dict" in motion_module_state_dict + else motion_module_state_dict + ) try: # extra steps for self-trained models state_dict = OrderedDict() @@ -104,14 +152,16 @@ def __init__(self, config="configs/prompts/animation.yaml") -> None: state_dict[key] = motion_module_state_dict[key] motion_module_state_dict = state_dict del state_dict - missing, unexpected = self.pipeline.unet.load_state_dict(motion_module_state_dict, strict=False) + missing, unexpected = self.pipeline.unet.load_state_dict( + motion_module_state_dict, strict=False + ) assert len(unexpected) == 0 except: _tmp_ = OrderedDict() for key in motion_module_state_dict.keys(): if "motion_modules" in key: if key.startswith("unet."): - _key = key.split('unet.')[-1] + _key = key.split("unet.")[-1] _tmp_[_key] = motion_module_state_dict[key] else: _tmp_[key] = motion_module_state_dict[key] @@ -122,74 +172,83 @@ def __init__(self, config="configs/prompts/animation.yaml") -> None: self.pipeline.to("cuda") self.L = config.L - + print("Initialization Done!") - - def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512): - prompt = n_prompt = "" - random_seed = int(random_seed) - step = int(step) - guidance_scale = float(guidance_scale) - samples_per_video = [] - # manually set random seed for reproduction - if random_seed != -1: - torch.manual_seed(random_seed) - set_seed(random_seed) - else: - torch.seed() - - if motion_sequence.endswith('.mp4'): - control = VideoReader(motion_sequence).read() - if control[0].shape[0] != size: - control = [np.array(Image.fromarray(c).resize((size, size))) for c in control] - control = np.array(control) - - if source_image.shape[0] != size: - source_image = np.array(Image.fromarray(source_image).resize((size, size))) - H, W, C = source_image.shape - - init_latents = None - original_length = control.shape[0] - if control.shape[0] % self.L > 0: - control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge') - generator = torch.Generator(device=torch.device("cuda:0")) - generator.manual_seed(torch.initial_seed()) - sample = self.pipeline( - prompt, - negative_prompt = n_prompt, - num_inference_steps = step, - guidance_scale = guidance_scale, - width = W, - height = H, - video_length = len(control), - controlnet_condition = control, - init_latents = init_latents, - generator = generator, - appearance_encoder = self.appearance_encoder, - reference_control_writer = self.reference_control_writer, - reference_control_reader = self.reference_control_reader, - source_image = source_image, - ).videos - - source_images = np.array([source_image] * original_length) - source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0 - samples_per_video.append(source_images) - - control = control / 255.0 - control = rearrange(control, "t h w c -> 1 c t h w") - control = torch.from_numpy(control) - samples_per_video.append(control[:, :, :original_length]) - - samples_per_video.append(sample[:, :, :original_length]) - - samples_per_video = torch.cat(samples_per_video) - - time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - savedir = f"demo/outputs" - animation_path = f"{savedir}/{time_str}.mp4" - - os.makedirs(savedir, exist_ok=True) - save_videos_grid(samples_per_video, animation_path) - - return animation_path - \ No newline at end of file + + def __call__( + self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512 + ): + prompt = n_prompt = "" + random_seed = int(random_seed) + step = int(step) + guidance_scale = float(guidance_scale) + samples_per_video = [] + # manually set random seed for reproduction + if random_seed != -1: + torch.manual_seed(random_seed) + set_seed(random_seed) + else: + torch.seed() + + if motion_sequence.endswith(".mp4"): + control = VideoReader(motion_sequence).read() + if control[0].shape[0] != size: + control = [ + np.array(Image.fromarray(c).resize((size, size))) for c in control + ] + control = np.array(control) + + if source_image.shape[0] != size: + source_image = np.array(Image.fromarray(source_image).resize((size, size))) + H, W, C = source_image.shape + + init_latents = None + original_length = control.shape[0] + if control.shape[0] % self.L > 0: + control = np.pad( + control, + ((0, self.L - control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), + mode="edge", + ) + generator = torch.Generator(device=torch.device("cuda:0")) + generator.manual_seed(torch.initial_seed()) + sample = self.pipeline( + prompt, + negative_prompt=n_prompt, + num_inference_steps=step, + guidance_scale=guidance_scale, + width=W, + height=H, + video_length=len(control), + controlnet_condition=control, + init_latents=init_latents, + generator=generator, + appearance_encoder=self.appearance_encoder, + reference_control_writer=self.reference_control_writer, + reference_control_reader=self.reference_control_reader, + source_image=source_image, + ).videos + + source_images = np.array([source_image] * original_length) + source_images = ( + rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0 + ) + samples_per_video.append(source_images) + + control = control / 255.0 + control = rearrange(control, "t h w c -> 1 c t h w") + control = torch.from_numpy(control) + samples_per_video.append(control[:, :, :original_length]) + + samples_per_video.append(sample[:, :, :original_length]) + + samples_per_video = torch.cat(samples_per_video) + + time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + savedir = f"demo/outputs" + animation_path = f"{savedir}/{time_str}.mp4" + + os.makedirs(savedir, exist_ok=True) + save_videos_grid(samples_per_video, animation_path) + + return animation_path diff --git a/magicanimate/models/model_util.py b/magicanimate/models/model_util.py new file mode 100644 index 00000000..eed89c28 --- /dev/null +++ b/magicanimate/models/model_util.py @@ -0,0 +1,265 @@ +from typing import Literal, Union, Optional + +import torch +from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection +from diffusers import ( + UNet2DConditionModel, + SchedulerMixin, + StableDiffusionPipeline, + StableDiffusionXLPipeline, + AutoencoderKL, +) +from diffusers.schedulers import ( + DDIMScheduler, + DDPMScheduler, + LMSDiscreteScheduler, + EulerAncestralDiscreteScheduler, + UniPCMultistepScheduler, +) + +from omegaconf import OmegaConf + +TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4" +TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1" + +AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a", "uniPC"] + +SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection] + +DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this + + +def load_diffusers_model( + pretrained_model_name_or_path: str, + v2: bool = False, + clip_skip: Optional[int] = None, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]: + if v2: + tokenizer = CLIPTokenizer.from_pretrained( + TOKENIZER_V2_MODEL_NAME, + subfolder="tokenizer", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + # default is clip skip 2 + num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + else: + tokenizer = CLIPTokenizer.from_pretrained( + TOKENIZER_V1_MODEL_NAME, + subfolder="tokenizer", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + unet = UNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="unet", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + + return tokenizer, text_encoder, unet, vae + + +def load_checkpoint_model( + checkpoint_path: str, + v2: bool = False, + clip_skip: Optional[int] = None, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]: + pipe = StableDiffusionPipeline.from_single_file( + checkpoint_path, + upcast_attention=True if v2 else False, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + unet = pipe.unet + tokenizer = pipe.tokenizer + text_encoder = pipe.text_encoder + vae = pipe.vae + if clip_skip is not None: + if v2: + text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1) + else: + text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1) + + del pipe + + return tokenizer, text_encoder, unet, vae + + +def load_models( + pretrained_model_name_or_path: str, + scheduler_name: str, + v2: bool = False, + v_pred: bool = False, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]: + if pretrained_model_name_or_path.endswith( + ".ckpt" + ) or pretrained_model_name_or_path.endswith(".safetensors"): + tokenizer, text_encoder, unet, vae = load_checkpoint_model( + pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype + ) + else: # diffusers + tokenizer, text_encoder, unet, vae = load_diffusers_model( + pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype + ) + + if scheduler_name: + scheduler = create_noise_scheduler( + scheduler_name, + prediction_type="v_prediction" if v_pred else "epsilon", + ) + else: + scheduler = None + + return tokenizer, text_encoder, unet, scheduler, vae + + +def load_diffusers_model_xl( + pretrained_model_name_or_path: str, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]: + # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet + + tokenizers = [ + CLIPTokenizer.from_pretrained( + pretrained_model_name_or_path, + subfolder="tokenizer", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ), + CLIPTokenizer.from_pretrained( + pretrained_model_name_or_path, + subfolder="tokenizer_2", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + pad_token_id=0, # same as open clip + ), + ] + + text_encoders = [ + CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ), + CLIPTextModelWithProjection.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder_2", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ), + ] + + unet = UNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="unet", + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + return tokenizers, text_encoders, unet, vae + + +def load_checkpoint_model_xl( + checkpoint_path: str, + weight_dtype: torch.dtype = torch.float32, +) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]: + pipe = StableDiffusionXLPipeline.from_single_file( + checkpoint_path, + torch_dtype=weight_dtype, + cache_dir=DIFFUSERS_CACHE_DIR, + ) + + unet = pipe.unet + vae = pipe.vae + tokenizers = [pipe.tokenizer, pipe.tokenizer_2] + text_encoders = [pipe.text_encoder, pipe.text_encoder_2] + if len(text_encoders) == 2: + text_encoders[1].pad_token_id = 0 + + del pipe + + return tokenizers, text_encoders, unet, vae + + +def load_models_xl( + pretrained_model_name_or_path: str, + scheduler_name: str, + weight_dtype: torch.dtype = torch.float32, + noise_scheduler_kwargs=None, +) -> tuple[ + list[CLIPTokenizer], + list[SDXL_TEXT_ENCODER_TYPE], + UNet2DConditionModel, + SchedulerMixin, +]: + if pretrained_model_name_or_path.endswith( + ".ckpt" + ) or pretrained_model_name_or_path.endswith(".safetensors"): + (tokenizers, text_encoders, unet, vae) = load_checkpoint_model_xl( + pretrained_model_name_or_path, weight_dtype + ) + else: # diffusers + (tokenizers, text_encoders, unet, vae) = load_diffusers_model_xl( + pretrained_model_name_or_path, weight_dtype + ) + if scheduler_name: + scheduler = create_noise_scheduler(scheduler_name, noise_scheduler_kwargs) + else: + scheduler = None + + return tokenizers, text_encoders, unet, scheduler, vae + + +def create_noise_scheduler( + scheduler_name: AVAILABLE_SCHEDULERS = "ddpm", + noise_scheduler_kwargs=None, + prediction_type: Literal["epsilon", "v_prediction"] = "epsilon", +) -> SchedulerMixin: + name = scheduler_name.lower().replace(" ", "_") + if name.lower() == "ddim": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim + scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) + elif name.lower() == "ddpm": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm + scheduler = DDPMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) + elif name.lower() == "lms": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete + scheduler = LMSDiscreteScheduler( + **OmegaConf.to_container(noise_scheduler_kwargs) + ) + elif name.lower() == "euler_a": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral + scheduler = EulerAncestralDiscreteScheduler( + **OmegaConf.to_container(noise_scheduler_kwargs) + ) + elif name.lower() == "unipc": + # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/unipc + scheduler = UniPCMultistepScheduler( + **OmegaConf.to_container(noise_scheduler_kwargs) + ) + else: + raise ValueError(f"Unknown scheduler name: {name}") + + return scheduler diff --git a/magicanimate/models/unet_controlnet.py b/magicanimate/models/unet_controlnet.py index 0ccd9cad..14b09c39 100644 --- a/magicanimate/models/unet_controlnet.py +++ b/magicanimate/models/unet_controlnet.py @@ -1,7 +1,7 @@ # ************************************************************************* # This file may have been modified by Bytedance Inc. (“Bytedance Inc.'s Mo- # difications”). All Bytedance Inc.'s Modifications are Copyright (2023) B- -# ytedance Inc.. +# ytedance Inc.. # ************************************************************************* # Copyright 2023 The HuggingFace Team. All rights reserved. @@ -29,6 +29,7 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin +from diffusers.loaders import UNet2DConditionLoadersMixin from diffusers.utils import BaseOutput, logging from diffusers.models.embeddings import TimestepEmbedding, Timesteps from magicanimate.models.unet_3d_blocks import ( @@ -51,7 +52,7 @@ class UNet3DConditionOutput(BaseOutput): sample: torch.FloatTensor -class UNet3DConditionModel(ModelMixin, ConfigMixin): +class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): _supports_gradient_checkpointing = True @register_to_config @@ -62,7 +63,7 @@ def __init__( out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, - freq_shift: int = 0, + freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", @@ -74,7 +75,7 @@ def __init__( "UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", - "CrossAttnUpBlock3D" + "CrossAttnUpBlock3D", ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), @@ -92,16 +93,15 @@ def __init__( num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", - # Additional - use_motion_module = False, - motion_module_resolutions = ( 1,2,4,8 ), - motion_module_mid_block = False, - motion_module_decoder_only = False, - motion_module_type = None, - motion_module_kwargs = {}, - unet_use_cross_frame_attention = None, - unet_use_temporal_attention = None, + use_motion_module=False, + motion_module_resolutions=(1, 2, 4, 8), + motion_module_mid_block=False, + motion_module_decoder_only=False, + motion_module_type=None, + motion_module_kwargs={}, + unet_use_cross_frame_attention=None, + unet_use_temporal_attention=None, ): super().__init__() @@ -109,7 +109,9 @@ def __init__( time_embed_dim = block_out_channels[0] * 4 # input - self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + self.conv_in = InflatedConv3d( + in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1) + ) # time self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) @@ -140,7 +142,7 @@ def __init__( # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): - res = 2 ** i + res = 2**i input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 @@ -163,11 +165,11 @@ def __init__( only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, - unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, - - use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), + use_motion_module=use_motion_module + and (res in motion_module_resolutions) + and (not motion_module_decoder_only), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) @@ -188,17 +190,15 @@ def __init__( dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, - unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, - use_motion_module=use_motion_module and motion_module_mid_block, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") - + # count how many layers upsample the videos self.num_upsamplers = 0 @@ -213,7 +213,9 @@ def __init__( prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] # add upsample block for all BUT final layer if not is_final_block: @@ -240,11 +242,10 @@ def __init__( only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, - unet_use_cross_frame_attention=unet_use_cross_frame_attention, unet_use_temporal_attention=unet_use_temporal_attention, - - use_motion_module=use_motion_module and (res in motion_module_resolutions), + use_motion_module=use_motion_module + and (res in motion_module_resolutions), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) @@ -252,9 +253,13 @@ def __init__( prev_output_channel = output_channel # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) self.conv_act = nn.SiLU() - self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + self.conv_out = InflatedConv3d( + block_out_channels[0], out_channels, kernel_size=3, padding=1 + ) def set_attention_slice(self, slice_size): r""" @@ -293,7 +298,11 @@ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): # make smallest slice possible slice_size = num_slicable_layers * [1] - slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + slice_size = ( + num_slicable_layers * [slice_size] + if not isinstance(slice_size, list) + else slice_size + ) if len(slice_size) != len(sliceable_head_dims): raise ValueError( @@ -310,7 +319,9 @@ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): # Recursively walk through all the children. # Any children which exposes the set_attention_slice method # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + def fn_recursive_set_attention_slice( + module: torch.nn.Module, slice_size: List[int] + ): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) @@ -322,7 +333,9 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + if isinstance( + module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D) + ): module.gradient_checkpointing = value def forward( @@ -399,7 +412,9 @@ def forward( if self.class_embedding is not None: if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") + raise ValueError( + "class_labels should be provided when num_class_embeds > 0" + ) if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) @@ -411,11 +426,17 @@ def forward( sample = self.conv_in(sample) # down - is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_controlnet = ( + mid_block_additional_residual is not None + and down_block_additional_residuals is not None + ) down_block_res_samples = (sample,) for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): sample, res_samples = downsample_block( hidden_states=sample, temb=emb, @@ -423,7 +444,11 @@ def forward( attention_mask=attention_mask, ) else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) down_block_res_samples += res_samples @@ -433,14 +458,21 @@ def forward( for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): - down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples = new_down_block_res_samples + ( + down_block_res_sample, + ) down_block_res_samples = new_down_block_res_samples # mid sample = self.mid_block( - sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, ) if is_controlnet: @@ -451,14 +483,19 @@ def forward( is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + if ( + hasattr(upsample_block, "has_cross_attention") + and upsample_block.has_cross_attention + ): sample = upsample_block( hidden_states=sample, temb=emb, @@ -469,7 +506,11 @@ def forward( ) else: sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + encoder_hidden_states=encoder_hidden_states, ) # post-process @@ -483,43 +524,60 @@ def forward( return UNet3DConditionOutput(sample=sample) @classmethod - def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): - if subfolder is not None: - pretrained_model_path = os.path.join(pretrained_model_path, subfolder) - print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...") - - config_file = os.path.join(pretrained_model_path, 'config.json') - if not os.path.isfile(config_file): - raise RuntimeError(f"{config_file} does not exist") - with open(config_file, "r") as f: - config = json.load(f) - config["_class_name"] = cls.__name__ - config["down_block_types"] = [ - "CrossAttnDownBlock3D", - "CrossAttnDownBlock3D", - "CrossAttnDownBlock3D", - "DownBlock3D" - ] - config["up_block_types"] = [ - "UpBlock3D", - "CrossAttnUpBlock3D", - "CrossAttnUpBlock3D", - "CrossAttnUpBlock3D" - ] - # config["mid_block_type"] = "UNetMidBlock3DCrossAttn" - - from diffusers.utils import WEIGHTS_NAME - model = cls.from_config(config, **unet_additional_kwargs) - model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) - if not os.path.isfile(model_file): - raise RuntimeError(f"{model_file} does not exist") - state_dict = torch.load(model_file, map_location="cpu") - - m, u = model.load_state_dict(state_dict, strict=False) - print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") - # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n") - - params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()] - print(f"### Temporal Module Parameters: {sum(params) / 1e6} M") - + def from_pretrained_2d( + cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None + ): + if type(pretrained_model_path) == str: + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + print( + f"loaded temporal unet's pretrained weights from {pretrained_model_path} ..." + ) + + config_file = os.path.join(pretrained_model_path, "config.json") + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ] + config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + ] + # config["mid_block_type"] = "UNetMidBlock3DCrossAttn" + + from diffusers.utils import WEIGHTS_NAME + + model = cls.from_config(config, **unet_additional_kwargs) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n") + + params = [ + p.numel() if "temporal" in n else 0 for n, p in model.named_parameters() + ] + print(f"### Temporal Module Parameters: {sum(params) / 1e6} M") + else: + state_dict = pretrained_model_path + m, u = model.load_state_dict(state_dict, strict=False) + print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n") + + params = [ + p.numel() if "temporal" in n else 0 for n, p in model.named_parameters() + ] + return model diff --git a/magicanimate/pipelines/animation.py b/magicanimate/pipelines/animation.py index 899583ed..33ed132c 100644 --- a/magicanimate/pipelines/animation.py +++ b/magicanimate/pipelines/animation.py @@ -31,6 +31,7 @@ from magicanimate.models.controlnet import ControlNetModel from magicanimate.models.appearance_encoder import AppearanceEncoderModel from magicanimate.models.mutual_self_attention import ReferenceAttentionControl +from magicanimate.models.model_util import load_models from magicanimate.pipelines.pipeline_animation import AnimationPipeline from magicanimate.utils.util import save_videos_grid from magicanimate.utils.dist_tools import distributed_init @@ -44,50 +45,85 @@ def main(args): - *_, func_args = inspect.getargvalues(inspect.currentframe()) func_args = dict(func_args) - - config = OmegaConf.load(args.config) - + + config = OmegaConf.load(args.config) + # Initialize distributed training device = torch.device(f"cuda:{args.rank}") - dist_kwargs = {"rank":args.rank, "world_size":args.world_size, "dist":args.dist} - + dist_kwargs = {"rank": args.rank, "world_size": args.world_size, "dist": args.dist} + if config.savename is None: time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") savedir = f"samples/{Path(args.config).stem}-{time_str}" else: savedir = f"samples/{config.savename}" - + if args.dist: dist.broadcast_object_list([savedir], 0) dist.barrier() - + if args.rank == 0: os.makedirs(savedir, exist_ok=True) inference_config = OmegaConf.load(config.inference_config) - + motion_module = config.motion_module - + ### >>> create animation pipeline >>> ### - tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer") - text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder") + tokenizer, text_encoder, unet, noise_scheduler, vae = load_models( + config.pretrained_model_path, + scheduler_name="", + v2=False, + v_pred=False, + ) + unet. + # tokenizer = CLIPTokenizer.from_pretrained( + # config.pretrained_model_path, subfolder="tokenizer" + # ) + # text_encoder = CLIPTextModel.from_pretrained( + # config.pretrained_model_path, subfolder="text_encoder" + # ) if config.pretrained_unet_path: - unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) + unet = UNet3DConditionModel.from_pretrained_2d( + config.pretrained_unet_path, + unet_additional_kwargs=OmegaConf.to_container( + inference_config.unet_additional_kwargs + ), + ) else: - unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) - appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").to(device) - reference_control_writer = ReferenceAttentionControl(appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks) - reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks) + unet = UNet3DConditionModel.from_pretrained_2d( + unet, + subfolder=None, + unet_additional_kwargs=OmegaConf.to_container( + inference_config.unet_additional_kwargs + ), + ) + appearance_encoder = AppearanceEncoderModel.from_pretrained( + config.pretrained_appearance_encoder_path, subfolder="appearance_encoder" + ).to(device) + reference_control_writer = ReferenceAttentionControl( + appearance_encoder, + do_classifier_free_guidance=True, + mode="write", + fusion_blocks=config.fusion_blocks, + ) + reference_control_reader = ReferenceAttentionControl( + unet, + do_classifier_free_guidance=True, + mode="read", + fusion_blocks=config.fusion_blocks, + ) if config.pretrained_vae_path is not None: vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path) - else: - vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae") + # else: + # vae = AutoencoderKL.from_pretrained( + # config.pretrained_model_path, subfolder="vae" + # ) ### Load controlnet - controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path) + controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path) unet.enable_xformers_memory_efficient_attention() appearance_encoder.enable_xformers_memory_efficient_attention() @@ -100,16 +136,27 @@ def main(args): controlnet.to(torch.float16) pipeline = AnimationPipeline( - vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, - scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=DDIMScheduler( + **OmegaConf.to_container(inference_config.noise_scheduler_kwargs) + ), # NOTE: UniPCMultistepScheduler ) # 1. unet ckpt # 1.1 motion module motion_module_state_dict = torch.load(motion_module, map_location="cpu") - if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]}) - motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict + if "global_step" in motion_module_state_dict: + func_args.update({"global_step": motion_module_state_dict["global_step"]}) + motion_module_state_dict = ( + motion_module_state_dict["state_dict"] + if "state_dict" in motion_module_state_dict + else motion_module_state_dict + ) try: # extra steps for self-trained models state_dict = OrderedDict() @@ -121,14 +168,16 @@ def main(args): state_dict[key] = motion_module_state_dict[key] motion_module_state_dict = state_dict del state_dict - missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False) + missing, unexpected = pipeline.unet.load_state_dict( + motion_module_state_dict, strict=False + ) assert len(unexpected) == 0 except: _tmp_ = OrderedDict() for key in motion_module_state_dict.keys(): if "motion_modules" in key: if key.startswith("unet."): - _key = key.split('unet.')[-1] + _key = key.split("unet.")[-1] _tmp_[_key] = motion_module_state_dict[key] else: _tmp_[key] = motion_module_state_dict[key] @@ -139,13 +188,19 @@ def main(args): pipeline.to(device) ### <<< create validation pipeline <<< ### - + random_seeds = config.get("seed", [-1]) - random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) - random_seeds = random_seeds * len(config.source_image) if len(random_seeds) == 1 else random_seeds - + random_seeds = ( + [random_seeds] if isinstance(random_seeds, int) else list(random_seeds) + ) + random_seeds = ( + random_seeds * len(config.source_image) + if len(random_seeds) == 1 + else random_seeds + ) + # input test videos (either source video/ conditions) - + test_videos = config.video_path source_images = config.source_image num_actual_inference_steps = config.get("num_actual_inference_steps", config.steps) @@ -157,88 +212,113 @@ def main(args): config.random_seed = [] prompt = n_prompt = "" for idx, (source_image, test_video, random_seed, size, step) in tqdm( - enumerate(zip(source_images, test_videos, random_seeds, sizes, steps)), - total=len(test_videos), - disable=(args.rank!=0) + enumerate(zip(source_images, test_videos, random_seeds, sizes, steps)), + total=len(test_videos), + disable=(args.rank != 0), ): samples_per_video = [] samples_per_clip = [] # manually set random seed for reproduction - if random_seed != -1: + if random_seed != -1: torch.manual_seed(random_seed) set_seed(random_seed) else: torch.seed() config.random_seed.append(torch.initial_seed()) - if test_video.endswith('.mp4'): + if test_video.endswith(".mp4"): control = VideoReader(test_video).read() if control[0].shape[0] != size: - control = [np.array(Image.fromarray(c).resize((size, size))) for c in control] + control = [ + np.array(Image.fromarray(c).resize((size, size))) for c in control + ] if config.max_length is not None: - control = control[config.offset: (config.offset+config.max_length)] + control = control[config.offset : (config.offset + config.max_length)] control = np.array(control) - + if source_image.endswith(".mp4"): - source_image = np.array(Image.fromarray(VideoReader(source_image).read()[0]).resize((size, size))) + source_image = np.array( + Image.fromarray(VideoReader(source_image).read()[0]).resize( + (size, size) + ) + ) else: source_image = np.array(Image.open(source_image).resize((size, size))) H, W, C = source_image.shape - + print(f"current seed: {torch.initial_seed()}") init_latents = None - + # print(f"sampling {prompt} ...") original_length = control.shape[0] if control.shape[0] % config.L > 0: - control = np.pad(control, ((0, config.L-control.shape[0] % config.L), (0, 0), (0, 0), (0, 0)), mode='edge') + control = np.pad( + control, + ((0, config.L - control.shape[0] % config.L), (0, 0), (0, 0), (0, 0)), + mode="edge", + ) generator = torch.Generator(device=torch.device("cuda:0")) generator.manual_seed(torch.initial_seed()) sample = pipeline( prompt, - negative_prompt = n_prompt, - num_inference_steps = config.steps, - guidance_scale = config.guidance_scale, - width = W, - height = H, - video_length = len(control), - controlnet_condition = control, - init_latents = init_latents, - generator = generator, - num_actual_inference_steps = num_actual_inference_steps, - appearance_encoder = appearance_encoder, - reference_control_writer = reference_control_writer, - reference_control_reader = reference_control_reader, - source_image = source_image, + negative_prompt=n_prompt, + num_inference_steps=config.steps, + guidance_scale=config.guidance_scale, + width=W, + height=H, + video_length=len(control), + controlnet_condition=control, + init_latents=init_latents, + generator=generator, + num_actual_inference_steps=num_actual_inference_steps, + appearance_encoder=appearance_encoder, + reference_control_writer=reference_control_writer, + reference_control_reader=reference_control_reader, + source_image=source_image, **dist_kwargs, ).videos if args.rank == 0: source_images = np.array([source_image] * original_length) - source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0 + source_images = ( + rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") + / 255.0 + ) samples_per_video.append(source_images) - + control = control / 255.0 control = rearrange(control, "t h w c -> 1 c t h w") control = torch.from_numpy(control) samples_per_video.append(control[:, :, :original_length]) samples_per_video.append(sample[:, :, :original_length]) - + samples_per_video = torch.cat(samples_per_video) video_name = os.path.basename(test_video)[:-4] source_name = os.path.basename(config.source_image[idx]).split(".")[0] - save_videos_grid(samples_per_video[-1:], f"{savedir}/videos/{source_name}_{video_name}.mp4") - save_videos_grid(samples_per_video, f"{savedir}/videos/{source_name}_{video_name}/grid.mp4") + save_videos_grid( + samples_per_video[-1:], + f"{savedir}/videos/{source_name}_{video_name}.mp4", + ) + save_videos_grid( + samples_per_video, + f"{savedir}/videos/{source_name}_{video_name}/grid.mp4", + ) if config.save_individual_videos: - save_videos_grid(samples_per_video[1:2], f"{savedir}/videos/{source_name}_{video_name}/ctrl.mp4") - save_videos_grid(samples_per_video[0:1], f"{savedir}/videos/{source_name}_{video_name}/orig.mp4") - + save_videos_grid( + samples_per_video[1:2], + f"{savedir}/videos/{source_name}_{video_name}/ctrl.mp4", + ) + save_videos_grid( + samples_per_video[0:1], + f"{savedir}/videos/{source_name}_{video_name}/orig.mp4", + ) + if args.dist: dist.barrier() - + if args.rank == 0: OmegaConf.save(config, f"{savedir}/config.yaml") @@ -254,7 +334,6 @@ def distributed_main(device_id, args): def run(args): - if args.dist: args.world_size = max(1, torch.cuda.device_count()) assert args.world_size <= torch.cuda.device_count()