From d734e602f828a464797d1dcf111d276e95117b5b Mon Sep 17 00:00:00 2001 From: cocktailpeanut Date: Tue, 5 Dec 2023 02:08:11 -0500 Subject: [PATCH] debug mode --- demo/animate.py | 22 +++++++++++----------- demo/gradio_animate.py | 9 +++++---- demo/gradio_animate_dist.py | 2 +- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/demo/animate.py b/demo/animate.py index b71f1940..3545e7f8 100644 --- a/demo/animate.py +++ b/demo/animate.py @@ -125,7 +125,7 @@ def __init__(self, config="configs/prompts/animation.yaml") -> None: print("Initialization Done!") - def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, size=512): + def __call__(self, source_image, motion_sequence, random_seed, step, guidance_scale, debug, size=512): prompt = n_prompt = "" random_seed = int(random_seed) step = int(step) @@ -171,17 +171,17 @@ def __call__(self, source_image, motion_sequence, random_seed, step, guidance_sc source_image = source_image, ).videos - source_images = np.array([source_image] * original_length) - source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0 - samples_per_video.append(source_images) - - control = control / 255.0 - control = rearrange(control, "t h w c -> 1 c t h w") - control = torch.from_numpy(control) - samples_per_video.append(control[:, :, :original_length]) + if debug: + source_images = np.array([source_image] * original_length) + source_images = rearrange(torch.from_numpy(source_images), "t h w c -> 1 c t h w") / 255.0 + samples_per_video.append(source_images) + + control = control / 255.0 + control = rearrange(control, "t h w c -> 1 c t h w") + control = torch.from_numpy(control) + samples_per_video.append(control[:, :, :original_length]) samples_per_video.append(sample[:, :, :original_length]) - samples_per_video = torch.cat(samples_per_video) time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") @@ -192,4 +192,4 @@ def __call__(self, source_image, motion_sequence, random_seed, step, guidance_sc save_videos_grid(samples_per_video, animation_path) return animation_path - \ No newline at end of file + diff --git a/demo/gradio_animate.py b/demo/gradio_animate.py index 10eaa133..918a7724 100644 --- a/demo/gradio_animate.py +++ b/demo/gradio_animate.py @@ -18,8 +18,8 @@ animator = MagicAnimate() -def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale): - return animator(reference_image, motion_sequence_state, seed, steps, guidance_scale) +def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale, debug): + return animator(reference_image, motion_sequence_state, seed, steps, guidance_scale, debug) with gr.Blocks() as demo: @@ -49,6 +49,7 @@ def animate(reference_image, motion_sequence_state, seed, steps, guidance_scale) random_seed = gr.Textbox(label="Random seed", value=1, info="default: -1") sampling_steps = gr.Textbox(label="Sampling steps", value=25, info="default: 25") guidance_scale = gr.Textbox(label="Guidance scale", value=7.5, info="default: 7.5") + debug = gr.Checkbox(label="Debug", value=True) submit = gr.Button("Animate") def read_video(video): @@ -74,7 +75,7 @@ def read_image(image, size=512): # when the `submit` button is clicked submit.click( animate, - [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale], + [reference_image, motion_sequence, random_seed, sampling_steps, guidance_scale, debug], animation ) @@ -94,4 +95,4 @@ def read_image(image, size=512): ) -demo.launch(share=True) \ No newline at end of file +demo.launch(share=True) diff --git a/demo/gradio_animate_dist.py b/demo/gradio_animate_dist.py index bfb7a613..90c66345 100644 --- a/demo/gradio_animate_dist.py +++ b/demo/gradio_animate_dist.py @@ -116,4 +116,4 @@ def read_image(image, size=512): ) demo.queue(max_size=10) -demo.launch(share=True) \ No newline at end of file +demo.launch(share=True)