Commit

Merge branch 'load_ckpt' of https://github.com/sdbds/magic-animate-fo…
sdbds committed Dec 6, 2023
2 parents f8fb4cb + 1d435fa commit 7b4119e
Showing 4 changed files with 734 additions and 239 deletions.
246 changes: 153 additions & 93 deletions demo/animate.py
@@ -29,6 +29,7 @@
from magicanimate.models.controlnet import ControlNetModel
from magicanimate.models.appearance_encoder import AppearanceEncoderModel
from magicanimate.models.mutual_self_attention import ReferenceAttentionControl
from magicanimate.models.model_util import load_models
from magicanimate.pipelines.pipeline_animation import AnimationPipeline
from magicanimate.utils.util import save_videos_grid
from accelerate.utils import set_seed
@@ -42,57 +43,105 @@
import math
from pathlib import Path

class MagicAnimate():

class MagicAnimate:
def __init__(self, config="configs/prompts/animation.yaml") -> None:
print("Initializing MagicAnimate Pipeline...")
*_, func_args = inspect.getargvalues(inspect.currentframe())
func_args = dict(func_args)
config = OmegaConf.load(config)

config = OmegaConf.load(config)

inference_config = OmegaConf.load(config.inference_config)

motion_module = config.motion_module

### >>> create animation pipeline >>> ###
tokenizer = CLIPTokenizer.from_pretrained(config.pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(config.pretrained_model_path, subfolder="text_encoder")
tokenizer, text_encoder, unet, noise_scheduler, vae = load_models(
config.pretrained_model_path,
scheduler_name="",
v2=False,
v_pred=False,
)
# tokenizer = CLIPTokenizer.from_pretrained(
# config.pretrained_model_path, subfolder="tokenizer"
# )
# text_encoder = CLIPTextModel.from_pretrained(
# config.pretrained_model_path, subfolder="text_encoder"
# )
if config.pretrained_unet_path:
unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_unet_path, unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
unet = UNet3DConditionModel.from_pretrained_2d(
config.pretrained_unet_path,
unet_additional_kwargs=OmegaConf.to_container(
inference_config.unet_additional_kwargs
),
)
else:
unet = UNet3DConditionModel.from_pretrained_2d(config.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
self.appearance_encoder = AppearanceEncoderModel.from_pretrained(config.pretrained_appearance_encoder_path, subfolder="appearance_encoder").cuda()
self.reference_control_writer = ReferenceAttentionControl(self.appearance_encoder, do_classifier_free_guidance=True, mode='write', fusion_blocks=config.fusion_blocks)
self.reference_control_reader = ReferenceAttentionControl(unet, do_classifier_free_guidance=True, mode='read', fusion_blocks=config.fusion_blocks)
unet = UNet3DConditionModel.from_pretrained_2d(
unet.config,
subfolder=None,
unet_additional_kwargs=OmegaConf.to_container(
inference_config.unet_additional_kwargs
),
)
self.appearance_encoder = AppearanceEncoderModel.from_pretrained(
config.pretrained_appearance_encoder_path, subfolder="appearance_encoder"
).cuda()
self.reference_control_writer = ReferenceAttentionControl(
self.appearance_encoder,
do_classifier_free_guidance=True,
mode="write",
fusion_blocks=config.fusion_blocks,
)
self.reference_control_reader = ReferenceAttentionControl(
unet,
do_classifier_free_guidance=True,
mode="read",
fusion_blocks=config.fusion_blocks,
)

if config.pretrained_vae_path:
vae = AutoencoderKL.from_pretrained(config.pretrained_vae_path)
else:
vae = AutoencoderKL.from_pretrained(config.pretrained_model_path, subfolder="vae")
# else:
# vae = AutoencoderKL.from_pretrained(
# config.pretrained_model_path, subfolder="vae"
# )

### Load controlnet
        controlnet = ControlNetModel.from_pretrained(config.pretrained_controlnet_path)

vae.to(torch.float16)
unet.to(torch.float16)
text_encoder.to(torch.float16)
controlnet.to(torch.float16)
self.appearance_encoder.to(torch.float16)

unet.enable_xformers_memory_efficient_attention()
self.appearance_encoder.enable_xformers_memory_efficient_attention()
controlnet.enable_xformers_memory_efficient_attention()

self.pipeline = AnimationPipeline(
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet,
scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
controlnet=controlnet,
scheduler=DDIMScheduler(
**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)
),
# NOTE: UniPCMultistepScheduler
).to("cuda")

# 1. unet ckpt
# 1.1 motion module
motion_module_state_dict = torch.load(motion_module, map_location="cpu")
if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
motion_module_state_dict = motion_module_state_dict['state_dict'] if 'state_dict' in motion_module_state_dict else motion_module_state_dict
if "global_step" in motion_module_state_dict:
func_args.update({"global_step": motion_module_state_dict["global_step"]})
motion_module_state_dict = (
motion_module_state_dict["state_dict"]
if "state_dict" in motion_module_state_dict
else motion_module_state_dict
)
try:
# extra steps for self-trained models
state_dict = OrderedDict()
@@ -104,14 +153,16 @@ def __init__(self, config="configs/prompts/animation.yaml") -> None:
state_dict[key] = motion_module_state_dict[key]
motion_module_state_dict = state_dict
del state_dict
missing, unexpected = self.pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
missing, unexpected = self.pipeline.unet.load_state_dict(
motion_module_state_dict, strict=False
)
assert len(unexpected) == 0
except:
_tmp_ = OrderedDict()
for key in motion_module_state_dict.keys():
if "motion_modules" in key:
if key.startswith("unet."):
_key = key.split('unet.')[-1]
_key = key.split("unet.")[-1]
_tmp_[_key] = motion_module_state_dict[key]
else:
_tmp_[key] = motion_module_state_dict[key]
@@ -122,74 +173,83 @@ def __init__(self, config="configs/prompts/animation.yaml") -> None:

self.pipeline.to("cuda")
self.L = config.L

print("Initialization Done!")

def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512):
prompt = n_prompt = ""
random_seed = int(random_seed)
step = int(step)
guidance_scale = float(guidance_scale)
samples_per_video = []
# manually set random seed for reproduction
if random_seed != -1:
torch.manual_seed(random_seed)
set_seed(random_seed)
else:
torch.seed()

if motion_sequence.endswith('.mp4'):
control = VideoReader(motion_sequence).read()
if control[0].shape[0] != size:
control = [np.array(Image.fromarray(c).resize((size, size))) for c in control]
control = np.array(control)

if source_image.shape[0] != size:
source_image = np.array(Image.fromarray(source_image).resize((size, size)))
H, W, C = source_image.shape

init_latents = None
original_length = control.shape[0]
if control.shape[0] % self.L > 0:
control = np.pad(control, ((0, self.L-control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)), mode='edge')
generator = torch.Generator(device=torch.device("cuda:0"))
generator.manual_seed(torch.initial_seed())
sample = self.pipeline(
prompt,
negative_prompt = n_prompt,
num_inference_steps = step,
guidance_scale = guidance_scale,
width = W,
height = H,
video_length = len(control),
controlnet_condition = control,
init_latents = init_latents,
generator = generator,
appearance_encoder = self.appearance_encoder,
reference_control_writer = self.reference_control_writer,
reference_control_reader = self.reference_control_reader,
source_image = source_image,
).videos

source_images = np.array([source_image] * original_length)
source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
samples_per_video.append(source_images)

control = control / 255.0
control = rearrange(control, "t h w c -> 1 c t h w")
control = torch.from_numpy(control)
samples_per_video.append(control[:, :, :original_length])

samples_per_video.append(sample[:, :, :original_length])

samples_per_video = torch.cat(samples_per_video)

time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
savedir = f"demo/outputs"
animation_path = f"{savedir}/{time_str}.mp4"

os.makedirs(savedir, exist_ok=True)
save_videos_grid(samples_per_video, animation_path)

return animation_path


def __call__(
self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512
):
prompt = n_prompt = ""
random_seed = int(random_seed)
step = int(step)
guidance_scale = float(guidance_scale)
samples_per_video = []
# manually set random seed for reproduction
if random_seed != -1:
torch.manual_seed(random_seed)
set_seed(random_seed)
else:
torch.seed()

if motion_sequence.endswith(".mp4"):
control = VideoReader(motion_sequence).read()
if control[0].shape[0] != size:
control = [
np.array(Image.fromarray(c).resize((size, size))) for c in control
]
control = np.array(control)

if source_image.shape[0] != size:
source_image = np.array(Image.fromarray(source_image).resize((size, size)))
H, W, C = source_image.shape

init_latents = None
original_length = control.shape[0]
if control.shape[0] % self.L > 0:
control = np.pad(
control,
((0, self.L - control.shape[0] % self.L), (0, 0), (0, 0), (0, 0)),
mode="edge",
)
generator = torch.Generator(device=torch.device("cuda:0"))
generator.manual_seed(torch.initial_seed())
sample = self.pipeline(
prompt,
negative_prompt=n_prompt,
num_inference_steps=step,
guidance_scale=guidance_scale,
width=W,
height=H,
video_length=len(control),
controlnet_condition=control,
init_latents=init_latents,
generator=generator,
appearance_encoder=self.appearance_encoder,
reference_control_writer=self.reference_control_writer,
reference_control_reader=self.reference_control_reader,
source_image=source_image,
).videos

source_images = np.array([source_image] * original_length)
source_images = (
rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0
)
samples_per_video.append(source_images)

control = control / 255.0
control = rearrange(control, "t h w c -> 1 c t h w")
control = torch.from_numpy(control)
samples_per_video.append(control[:, :, :original_length])

samples_per_video.append(sample[:, :, :original_length])

samples_per_video = torch.cat(samples_per_video)

time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
savedir = f"demo/outputs"
animation_path = f"{savedir}/{time_str}.mp4"

os.makedirs(savedir, exist_ok=True)
save_videos_grid(samples_per_video, animation_path)

return animation_path
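
The substantive change in demo/animate.py is that the tokenizer, text encoder, UNet, noise scheduler, and VAE are now obtained from a single load_models() call in magicanimate.models.model_util, with the previous per-component from_pretrained calls left commented out; presumably this lets the demo load single-file checkpoints as well, as the 'load_ckpt' branch name suggests. The rest of the diff is largely reformatting. As a reference point, here is a minimal usage sketch of the refactored class; it is not part of this commit. It assumes the repository root is on PYTHONPATH, that configs/prompts/animation.yaml points at downloaded checkpoints, and that a CUDA GPU is available (the pipeline is moved to "cuda"); the input paths are illustrative examples, so adjust them to your local files.

# Minimal usage sketch (not part of this commit). Assumes a CUDA GPU, the
# repository root on PYTHONPATH, and the checkpoints referenced by
# configs/prompts/animation.yaml already downloaded; the input paths below
# are illustrative.
import numpy as np
from PIL import Image

from demo.animate import MagicAnimate

animator = MagicAnimate(config="configs/prompts/animation.yaml")

# __call__ takes an H x W x C uint8 reference image and a path to a DensePose
# motion video (.mp4); it writes the result under demo/outputs/ and returns the path.
source_image = np.array(
    Image.open("inputs/applications/source_image/monalisa.png").convert("RGB")
)
animation_path = animator(
    source_image,
    "inputs/applications/driving/densepose/running.mp4",
    random_seed=1,
    step=25,
    guidance_scale=7.5,
    size=512,
)
print(animation_path)  # e.g. demo/outputs/2023-12-06T12-00-00.mp4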