diff --git a/README.md b/README.md index d9b3714c..8454fc9f 100755 --- a/README.md +++ b/README.md @@ -9,22 +9,23 @@ - We are releasing **[Stable Video 4D (SV4D)](https://huggingface.co/stabilityai/sv4d)**, a video-to-4D diffusion model for novel-view video synthesis. For research purposes: - **SV4D** was trained to generate 40 frames (5 video frames x 8 camera views) at 576x576 resolution, given 5 context frames (the input video), and 8 reference views (synthesised from the first frame of the input video, using a multi-view diffusion model like SV3D) of the same size, ideally white-background images with one object. - To generate longer novel-view videos (21 frames), we propose a novel sampling method using SV4D, by first sampling 5 anchor frames and then densely sampling the remaining frames while maintaining temporal consistency. + - To run the community-built gradio demo locally, run `python -m scripts.demo.gradio_app_sv4d`. - Please check our [project page](https://sv4d.github.io), [tech report](https://sv4d.github.io/static/sv4d_technical_report.pdf) and [video summary](https://www.youtube.com/watch?v=RBP8vdAWTgk) for more details. -**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [sv4d.safetensors](https://huggingface.co/stabilityai/sv4d) and [sv3d_u.safetensors](https://huggingface.co/stabilityai/sv3d) from HuggingFace into `checkpoints/`) +**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [sv4d.safetensors](https://huggingface.co/stabilityai/sv4d) and [sv3d_u.safetensors](https://huggingface.co/stabilityai/sv3d) from HuggingFace into `checkpoints/`) To run **SV4D** on a single input video of 21 frames: - Download SV3D models (`sv3d_u.safetensors` and `sv3d_p.safetensors`) from [here](https://huggingface.co/stabilityai/sv3d) and SV4D model (`sv4d.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d) to `checkpoints/` - Run `python scripts/sampling/simple_video_sample_4d.py --input_path <path>` - `input_path` : The input video `<path>` can be - - a single video file in `gif` or `mp4` format, such as `assets/test_video1.mp4`, or + - a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/test_video1.mp4`, or - a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or - a file name pattern matching images of video frames. - `num_steps` : default is 20, can increase to 50 for better quality but longer sampling time. - `sv3d_version` : To specify the SV3D model to generate reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p. - - `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0` - - **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Cliipdrop](https://clipdrop.co/) before running SV4D. 
- - **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--decoding_t=1` (of frames decoded at a time) or lower video resolution like `--img_size=512`. + - `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path assets/sv4d_videos/test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0` + - **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D. + - **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--encoding_t=1` (number of frames encoded at a time) and `--decoding_t=1` (number of frames decoded at a time) or lower video resolution like `--img_size=512`. ![tile](assets/sv4d.gif) diff --git a/assets/sv4d_videos/bunnyman.mp4 b/assets/sv4d_videos/bunnyman.mp4 new file mode 100644 index 00000000..40a5b27d Binary files /dev/null and b/assets/sv4d_videos/bunnyman.mp4 differ diff --git a/assets/sv4d_videos/dolphin.mp4 b/assets/sv4d_videos/dolphin.mp4 new file mode 100644 index 00000000..e80f4601 Binary files /dev/null and b/assets/sv4d_videos/dolphin.mp4 differ diff --git a/assets/sv4d_videos/green_robot.mp4 b/assets/sv4d_videos/green_robot.mp4 new file mode 100644 index 00000000..428049c6 Binary files /dev/null and b/assets/sv4d_videos/green_robot.mp4 differ diff --git a/assets/sv4d_videos/guppie_v0.mp4 b/assets/sv4d_videos/guppie_v0.mp4 new file mode 100644 index 00000000..76764d5c Binary files /dev/null and b/assets/sv4d_videos/guppie_v0.mp4 differ diff --git a/assets/hiphop_parrot.mp4 b/assets/sv4d_videos/hiphop_parrot.mp4 similarity index 100% rename from assets/hiphop_parrot.mp4 rename to assets/sv4d_videos/hiphop_parrot.mp4 diff --git a/assets/sv4d_videos/human5.mp4 b/assets/sv4d_videos/human5.mp4 new file mode 100644 index 00000000..1962b5e6 Binary files /dev/null and b/assets/sv4d_videos/human5.mp4 differ diff --git a/assets/sv4d_videos/human7.mp4 b/assets/sv4d_videos/human7.mp4 new file mode 100644 index 00000000..3d8bb72b Binary files /dev/null and b/assets/sv4d_videos/human7.mp4 differ diff --git a/assets/sv4d_videos/lucia_v000.mp4 b/assets/sv4d_videos/lucia_v000.mp4 new file mode 100644 index 00000000..83f4d9e9 Binary files /dev/null and b/assets/sv4d_videos/lucia_v000.mp4 differ diff --git a/assets/sv4d_videos/monkey.mp4 b/assets/sv4d_videos/monkey.mp4 new file mode 100644 index 00000000..5434a246 Binary files /dev/null and b/assets/sv4d_videos/monkey.mp4 differ diff --git a/assets/sv4d_videos/pistol_v0.mp4 b/assets/sv4d_videos/pistol_v0.mp4 new file mode 100644 index 00000000..83c0e982 Binary files /dev/null and b/assets/sv4d_videos/pistol_v0.mp4 differ diff --git a/assets/sv4d_videos/snowboard_v000.mp4 b/assets/sv4d_videos/snowboard_v000.mp4 new file mode 100644 index 00000000..5c1b67a4 Binary files /dev/null and b/assets/sv4d_videos/snowboard_v000.mp4 differ diff --git a/assets/sv4d_videos/stroller_v000.mp4 b/assets/sv4d_videos/stroller_v000.mp4 new file mode 100644 index 00000000..e293bd5c Binary files /dev/null and b/assets/sv4d_videos/stroller_v000.mp4 differ diff --git a/assets/test_video1.mp4 
b/assets/sv4d_videos/test_video1.mp4 similarity index 100% rename from assets/test_video1.mp4 rename to assets/sv4d_videos/test_video1.mp4 diff --git a/assets/test_video2.mp4 b/assets/sv4d_videos/test_video2.mp4 similarity index 100% rename from assets/test_video2.mp4 rename to assets/sv4d_videos/test_video2.mp4 diff --git a/assets/sv4d_videos/train_v0.mp4 b/assets/sv4d_videos/train_v0.mp4 new file mode 100644 index 00000000..cb5f76f3 Binary files /dev/null and b/assets/sv4d_videos/train_v0.mp4 differ diff --git a/assets/sv4d_videos/wave_hello.mp4 b/assets/sv4d_videos/wave_hello.mp4 new file mode 100644 index 00000000..4c7693f5 Binary files /dev/null and b/assets/sv4d_videos/wave_hello.mp4 differ diff --git a/scripts/demo/gradio_app_sv4d.py b/scripts/demo/gradio_app_sv4d.py new file mode 100644 index 00000000..d4d8b37f --- /dev/null +++ b/scripts/demo/gradio_app_sv4d.py @@ -0,0 +1,496 @@ +# Adding this at the very top of app.py to make 'generative-models' directory discoverable +import os +import sys + +sys.path.append(os.path.join(os.path.dirname(__file__), "generative-models")) + +from glob import glob +from typing import Optional + +import gradio as gr +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from typing import List, Optional, Union +import torchvision + +from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder +from scripts.demo.sv4d_helpers import ( + decode_latents, + load_model, + initial_model_load, + read_video, + run_img2vid, + prepare_inputs, + do_sample_per_step, + sample_sv3d, + save_video, + preprocess_video, +) + + +# the tmp path, if /tmp/gradio is not writable, change it to a writable path +# os.environ["GRADIO_TEMP_DIR"] = "gradio_tmp" + +version = "sv4d" # replace with 'sv3d_p' or 'sv3d_u' for other models + +# Define the repo, local directory and filename +repo_id = "stabilityai/sv4d" +filename = f"{version}.safetensors" # replace with "sv3d_u.safetensors" or "sv3d_p.safetensors" +local_dir = "checkpoints" +local_ckpt_path = os.path.join(local_dir, filename) + +# Check if the file already exists +if not os.path.exists(local_ckpt_path): + # If the file doesn't exist, download it + hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir) + print("File downloaded. (sv4d)") +else: + print("File already exists. No need to download. 
(sv4d)") + +device = "cuda" +max_64_bit_int = 2**63 - 1 + +num_frames = 21 +num_steps = 20 +model_config = f"scripts/sampling/configs/{version}.yaml" + +# Set model config +T = 5 # number of frames per sample +V = 8 # number of views per sample +F = 8 # vae factor to downsize image->latent +C = 4 +H, W = 576, 576 +n_frames = 21 # number of input and output video frames +n_views = V + 1 # number of output video views (1 input view + 8 novel views) +n_views_sv3d = 21 +subsampled_views = np.array( + [0, 2, 5, 7, 9, 12, 14, 16, 19] +) # subsample (V+1=)9 (uniform) views from 21 SV3D views + +version_dict = { + "T": T * V, + "H": H, + "W": W, + "C": C, + "f": F, + "options": { + "discretization": 1, + "cfg": 3, + "sigma_min": 0.002, + "sigma_max": 700.0, + "rho": 7.0, + "guider": 5, + "num_steps": num_steps, + "force_uc_zero_embeddings": [ + "cond_frames", + "cond_frames_without_noise", + "cond_view", + "cond_motion", + ], + "additional_guider_kwargs": { + "additional_cond_keys": ["cond_view", "cond_motion"] + }, + }, +} + +# Load SV4D model +model, filter = load_model( + model_config, + device, + version_dict["T"], + num_steps, +) +model = initial_model_load(model) + +# -----------sv3d config and model loading---------------- +# if version == "sv3d_u": +sv3d_model_config = "scripts/sampling/configs/sv3d_u.yaml" +# elif version == "sv3d_p": +# sv3d_model_config = "scripts/sampling/configs/sv3d_p.yaml" +# else: +# raise ValueError(f"Version {version} does not exist.") + +# Define the repo, local directory and filename +repo_id = "stabilityai/sv3d" +filename = f"sv3d_u.safetensors" # replace with "sv3d_u.safetensors" or "sv3d_p.safetensors" +local_dir = "checkpoints" +local_ckpt_path = os.path.join(local_dir, filename) + +# Check if the file already exists +if not os.path.exists(local_ckpt_path): + # If the file doesn't exist, download it + hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir) + print("File downloaded. (sv3d)") +else: + print("File already exists. No need to download. (sv3d)") + +# load sv3d model +sv3d_model, filter = load_model( + sv3d_model_config, + device, + 21, + num_steps, + verbose=False, +) +sv3d_model = initial_model_load(sv3d_model) +# ------------------ + +def sample_anchor( + input_path: str = "assets/test_image.png", # Can either be image file or folder with image files + seed: Optional[int] = None, + encoding_t: int = 8, # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary. + decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. + num_steps: int = 20, + sv3d_version: str = "sv3d_u", # sv3d_u or sv3d_p + fps_id: int = 6, + motion_bucket_id: int = 127, + cond_aug: float = 1e-5, + device: str = "cuda", + elevations_deg: Optional[Union[float, List[float]]] = 10.0, + azimuths_deg: Optional[List[float]] = None, + verbose: Optional[bool] = False, +): + """ + Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each + image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. + """ + output_folder = os.path.dirname(input_path) + + torch.manual_seed(seed) + os.makedirs(output_folder, exist_ok=True) + + # Read input video frames i.e. 
images at view 0 + print(f"Reading {input_path}") + images_v0 = read_video( + input_path, + n_frames=n_frames, + device=device, + ) + + # Get camera viewpoints + if isinstance(elevations_deg, float) or isinstance(elevations_deg, int): + elevations_deg = [elevations_deg] * n_views_sv3d + assert ( + len(elevations_deg) == n_views_sv3d + ), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}" + if azimuths_deg is None: + azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360 + assert ( + len(azimuths_deg) == n_views_sv3d + ), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}" + polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg]) + azimuths_rad = np.array( + [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg] + ) + + # Sample multi-view images of the first frame using SV3D i.e. images at time 0 + sv3d_model.sampler.num_steps = num_steps + print("sv3d_model.sampler.num_steps", sv3d_model.sampler.num_steps) + images_t0 = sample_sv3d( + images_v0[0], + n_views_sv3d, + num_steps, + sv3d_version, + fps_id, + motion_bucket_id, + cond_aug, + decoding_t, + device, + polars_rad, + azimuths_rad, + verbose, + sv3d_model, + ) + images_t0 = torch.roll(images_t0, 1, 0) # move conditioning image to first frame + + sv3d_file = os.path.join(output_folder, "t000.mp4") + save_video(sv3d_file, images_t0.unsqueeze(1)) + + for emb in model.conditioner.embedders: + if isinstance(emb, VideoPredictionEmbedderWithEncoder): + emb.en_and_decode_n_samples_a_time = encoding_t + model.en_and_decode_n_samples_a_time = decoding_t + # Initialize image matrix + img_matrix = [[None] * n_views for _ in range(n_frames)] + for i, v in enumerate(subsampled_views): + img_matrix[0][i] = images_t0[v].unsqueeze(0) + for t in range(n_frames): + img_matrix[t][0] = images_v0[t] + + # Interleaved sampling for anchor frames + t0, v0 = 0, 0 + frame_indices = np.arange(T - 1, n_frames, T - 1) # [4, 8, 12, 16, 20] + view_indices = np.arange(V) + 1 + print(f"Sampling anchor frames {frame_indices}") + image = img_matrix[t0][v0] + cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0) + cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0) + polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() + azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() + azims = (azims - azimuths_rad[v0]) % (torch.pi * 2) + model.sampler.num_steps = num_steps + version_dict["options"]["num_steps"] = num_steps + samples = run_img2vid( + version_dict, model, image, seed, polars, azims, cond_motion, cond_view, decoding_t + ) + samples = samples.view(T, V, 3, H, W) + for i, t in enumerate(frame_indices): + for j, v in enumerate(view_indices): + if img_matrix[t][v] is None: + img_matrix[t][v] = samples[i, j][None] * 2 - 1 + + # concat video + grid_list = [] + for t in frame_indices: + imgs_view = torch.cat(img_matrix[t]) + grid_list.append(torchvision.utils.make_grid(imgs_view, nrow=3).unsqueeze(0)) + # save output videos + anchor_vis_file = os.path.join(output_folder, "anchor_vis.mp4") + save_video(anchor_vis_file, grid_list, fps=3) + anchor_file = os.path.join(output_folder, "anchor.mp4") + image_list = samples.view(T*V, 3, H, W).unsqueeze(1) * 2 - 1 + save_video(anchor_file, image_list) + + return sv3d_file, anchor_vis_file, anchor_file + + +def sample_all( + input_path: str = "inputs/test_video1.mp4", # Can either be video file or folder with image files + sv3d_path: 
str = "outputs/sv4d/000000_t000.mp4", + anchor_path: str = "outputs/sv4d/000000_anchor.mp4", + seed: Optional[int] = None, + num_steps: int = 20, + device: str = "cuda", + elevations_deg: Optional[Union[float, List[float]]] = 10.0, + azimuths_deg: Optional[List[float]] = None, +): + """ + Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each + image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. + """ + output_folder = os.path.dirname(input_path) + torch.manual_seed(seed) + os.makedirs(output_folder, exist_ok=True) + + # Read input video frames i.e. images at view 0 + print(f"Reading {input_path}") + images_v0 = read_video( + input_path, + n_frames=n_frames, + device=device, + ) + + images_t0 = read_video( + sv3d_path, + n_frames=n_views_sv3d, + device=device, + ) + + # Get camera viewpoints + if isinstance(elevations_deg, float) or isinstance(elevations_deg, int): + elevations_deg = [elevations_deg] * n_views_sv3d + assert ( + len(elevations_deg) == n_views_sv3d + ), f"Please provide 1 value, or a list of {n_views_sv3d} values for elevations_deg! Given {len(elevations_deg)}" + if azimuths_deg is None: + azimuths_deg = np.linspace(0, 360, n_views_sv3d + 1)[1:] % 360 + assert ( + len(azimuths_deg) == n_views_sv3d + ), f"Please provide a list of {n_views_sv3d} values for azimuths_deg! Given {len(azimuths_deg)}" + polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg]) + azimuths_rad = np.array( + [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg] + ) + + # Initialize image matrix + img_matrix = [[None] * n_views for _ in range(n_frames)] + for i, v in enumerate(subsampled_views): + img_matrix[0][i] = images_t0[v] + for t in range(n_frames): + img_matrix[t][0] = images_v0[t] + + # load interleaved sampling for anchor frames + t0, v0 = 0, 0 + frame_indices = np.arange(T - 1, n_frames, T - 1) # [4, 8, 12, 16, 20] + view_indices = np.arange(V) + 1 + + anchor_frames = read_video( + anchor_path, + n_frames=T * V, + device=device, + ) + anchor_frames = torch.cat(anchor_frames).view(T, V, 3, H, W) + for i, t in enumerate(frame_indices): + for j, v in enumerate(view_indices): + if img_matrix[t][v] is None: + img_matrix[t][v] = anchor_frames[i, j][None] + + # Dense sampling for the rest + print(f"Sampling dense frames:") + for t0 in np.arange(0, n_frames - 1, T - 1): # [0, 4, 8, 12, 16] + frame_indices = t0 + np.arange(T) + print(f"Sampling dense frames {frame_indices}") + latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda") + + polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() + azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() + azims = (azims - azimuths_rad[v0]) % (torch.pi * 2) + + # alternate between forward and backward conditioning + forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices = prepare_inputs( + frame_indices, + img_matrix, + v0, + view_indices, + model, + version_dict, + seed, + polars, + azims + ) + + for step in range(num_steps): + if step % 2 == 1: + c, uc, additional_model_inputs, sampler = forward_inputs + frame_indices = forward_frame_indices + else: + c, uc, additional_model_inputs, sampler = backward_inputs + frame_indices = backward_frame_indices + noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1) + + samples = do_sample_per_step( + model, + sampler, + noisy_latents, + c, + uc, + step, + additional_model_inputs, + ) + samples = 
samples.view(T, V, C, H // F, W // F) + for i, t in enumerate(frame_indices): + for j, v in enumerate(view_indices): + latent_matrix[t, v] = samples[i, j] + + img_matrix = decode_latents(model, latent_matrix, img_matrix, frame_indices, view_indices, T) + + + # concat video + grid_list = [] + for t in range(n_frames): + imgs_view = torch.cat(img_matrix[t]) + grid_list.append(torchvision.utils.make_grid(imgs_view, nrow=3).unsqueeze(0)) + # save output videos + vid_file = os.path.join(output_folder, "sv4d_final.mp4") + save_video(vid_file, grid_list) + + return vid_file, seed + + +with gr.Blocks() as demo: + gr.Markdown( """# Demo for SV4D from Stability AI ([model](https://huggingface.co/stabilityai/sv4d), [news](https://stability.ai/news/stable-video-4d)) +#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/sv4d/blob/main/LICENSE.md)): generate 8 novel view videos from a single-view video (with white background). +#### It takes ~45s to generate anchor frames and another ~160s to generate full results (21 frames). +#### Hints for improving performance: +- Use a white background; +- Center the object in the image; +- SV4D processes the first 21 frames of the uploaded video; Gradio lets you trim the uploaded video if needed. + """ + ) + with gr.Row(): + with gr.Column(): + input_video = gr.Video(label="Upload your video") + generate_btn = gr.Button("Step 1: Generate 8 novel view videos (5 anchor frames each)") + interpolate_btn = gr.Button("Step 2: Extend novel view videos to 21 frames") + with gr.Column(): + anchor_video = gr.Video(label="SV4D outputs (anchor frames)") + sv3d_video = gr.Video(label="SV3D outputs", interactive=False) + with gr.Column(): + sv4d_interpolated_video = gr.Video(label="SV4D outputs (21 frames)") + + with gr.Accordion("Advanced options", open=False): + seed = gr.Slider( + label="Seed", + value=23, + # randomize=True, + minimum=0, + maximum=100, + step=1, + ) + encoding_t = gr.Slider( + label="Encode n frames at a time", + info="Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.", + value=8, + minimum=1, + maximum=40, + ) + decoding_t = gr.Slider( + label="Decode n frames at a time", + info="Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.", + value=4, + minimum=1, + maximum=14, + ) + denoising_steps = gr.Slider( + label="Number of denoising steps", + info="Increasing this improves quality but takes more time.", + value=20, + minimum=10, + maximum=50, + step=1, + ) + remove_bg = gr.Checkbox( + label="Remove background", + info="We use rembg. 
Users can check the alternative way: SAM2 (https://github.com/facebookresearch/segment-anything-2)", + ) + + input_video.upload(fn=preprocess_video, inputs=[input_video, remove_bg], outputs=input_video, queue=False) + + with gr.Row(visible=False): + anchor_frames = gr.Video() + + generate_btn.click( + fn=sample_anchor, + inputs=[input_video, seed, encoding_t, decoding_t, denoising_steps], + outputs=[sv3d_video, anchor_video, anchor_frames], + api_name="SV4D output (5 frames)", + ) + + interpolate_btn.click( + fn=sample_all, + inputs=[input_video, sv3d_video, anchor_frames, seed, denoising_steps], + outputs=[sv4d_interpolated_video, seed], + api_name="SV4D interpolation (21 frames)", + ) + + examples = gr.Examples( + fn=preprocess_video, + examples=[ + "./assets/sv4d_videos/test_video1.mp4", + "./assets/sv4d_videos/test_video2.mp4", + "./assets/sv4d_videos/green_robot.mp4", + "./assets/sv4d_videos/dolphin.mp4", + "./assets/sv4d_videos/lucia_v000.mp4", + "./assets/sv4d_videos/snowboard_v000.mp4", + "./assets/sv4d_videos/stroller_v000.mp4", + "./assets/sv4d_videos/human5.mp4", + "./assets/sv4d_videos/bunnyman.mp4", + "./assets/sv4d_videos/hiphop_parrot.mp4", + "./assets/sv4d_videos/guppie_v0.mp4", + "./assets/sv4d_videos/wave_hello.mp4", + "./assets/sv4d_videos/pistol_v0.mp4", + "./assets/sv4d_videos/human7.mp4", + "./assets/sv4d_videos/monkey.mp4", + "./assets/sv4d_videos/train_v0.mp4", + ], + inputs=[input_video], + run_on_click=True, + outputs=[input_video], + ) + +if __name__ == "__main__": + demo.queue(max_size=20) + demo.launch(share=True) + \ No newline at end of file diff --git a/scripts/demo/sv4d_helpers.py b/scripts/demo/sv4d_helpers.py index d67231df..7296c58c 100755 --- a/scripts/demo/sv4d_helpers.py +++ b/scripts/demo/sv4d_helpers.py @@ -121,10 +121,6 @@ def save_video(file_name, imgs, fps=10): def read_video( input_path: str, n_frames: int, - W: int, - H: int, - remove_bg: bool = False, - image_frame_ratio: Optional[float] = None, device: str = "cuda", ): path = Path(input_path) @@ -158,46 +154,120 @@ def read_video( if len(images) < n_frames: images = (images + images[::-1])[:n_frames] - if len(images) != n_frames: raise ValueError(f"Input video contains fewer than {n_frames} frames.") - # Remove background and crop video frames images_v0 = [] - for t, image in enumerate(images): + + for image in images: + image = ToTensor()(image).unsqueeze(0).to(device) + images_v0.append(image * 2.0 - 1.0) + return images_v0 + + +def preprocess_video(input_path, remove_bg=False, n_frames=21, W=576, H=576, output_folder=None, image_frame_ratio = 0.917): + print(f"preprocess {input_path}") + if output_folder is None: + output_folder = os.path.dirname(input_path) + path = Path(input_path) + is_video_file = False + all_img_paths = [] + if path.is_file(): + if any([input_path.endswith(x) for x in [".gif", ".mp4"]]): + is_video_file = True + else: + raise ValueError("Path is not a valid video file.") + elif path.is_dir(): + all_img_paths = sorted( + [ + f + for f in path.iterdir() + if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] + ] + )[:n_frames] + elif "*" in input_path: + all_img_paths = sorted(glob(input_path))[:n_frames] + else: + raise ValueError + + if is_video_file and input_path.endswith(".gif"): + images = read_gif(input_path, n_frames)[:n_frames] + elif is_video_file and input_path.endswith(".mp4"): + images = read_mp4(input_path, n_frames)[:n_frames] + else: + print(f"Loading {len(all_img_paths)} video frames...") + images = [Image.open(img_path) for img_path in 
all_img_paths] + + if len(images) != n_frames: + raise ValueError(f"Input video contains {len(images)} frames, fewer than {n_frames} frames.") + + # Remove background + for i, image in enumerate(images): if remove_bg: - if image.mode != "RGBA": - image.thumbnail([W, H], Image.Resampling.LANCZOS) + if image.mode == "RGBA": + pass + else: + # image.thumbnail([W, H], Image.Resampling.LANCZOS) image = remove(image.convert("RGBA"), alpha_matting=True) - image_arr = np.array(image) - in_w, in_h = image_arr.shape[:2] + images[i] = image + + # Crop video frames, assume the object is already in the center of the image + white_thresh = 250 + images_v0 = [] + box_coord = [np.inf, np.inf, 0, 0] + for image in images: + image_arr = np.array(image) + in_w, in_h = image_arr.shape[:2] + original_center = (in_w // 2, in_h // 2) + if image.mode == "RGBA": ret, mask = cv2.threshold( np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY ) - x, y, w, h = cv2.boundingRect(mask) - max_size = max(w, h) - if t == 0: - side_len = ( - int(max_size / image_frame_ratio) - if image_frame_ratio is not None - else in_w - ) - padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) - center = side_len // 2 - padded_image[ - center - h // 2 : center - h // 2 + h, - center - w // 2 : center - w // 2 + w, - ] = image_arr[y : y + h, x : x + w] - rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS) - rgba_arr = np.array(rgba) / 255.0 - rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:]) - image = Image.fromarray((rgb * 255).astype(np.uint8)) else: - image = image.convert("RGB").resize((W, H), Image.LANCZOS) - image = ToTensor()(image).unsqueeze(0).to(device) - images_v0.append(image * 2.0 - 1.0) - return images_v0 - + # assume the input image has white background + ret, mask = cv2.threshold( + (np.array(image).mean(-1) <= white_thresh).astype(np.uint8) * 255, 0, 255, cv2.THRESH_BINARY + ) + + x, y, w, h = cv2.boundingRect(mask) + box_coord[0] = min(box_coord[0], x) + box_coord[1] = min(box_coord[1], y) + box_coord[2] = max(box_coord[2], x + w) + box_coord[3] = max(box_coord[3], y + h) + box_square = max(original_center[0] - box_coord[0], original_center[1] - box_coord[1]) + box_square = max(box_square, box_coord[2] - original_center[0]) + box_square = max(box_square, box_coord[3] - original_center[1]) + x, y, w, h = original_center[0] - box_square, original_center[1] - box_square, 2 * box_square, 2 * box_square + box_size = box_square * 2 + + for image in images: + if image.mode == "RGB": + image = image.convert("RGBA") + image_arr = np.array(image) + side_len = ( + int(box_size / image_frame_ratio) + if image_frame_ratio is not None + else in_w + ) + padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) + center = side_len // 2 + padded_image[ + center - box_size // 2 : center - box_size // 2 + box_size, + center - box_size // 2 : center - box_size // 2 + box_size, + ] = image_arr[x : x + w, y : y + h] + + rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS) + # rgba = image.resize((W, H), Image.LANCZOS) + rgba_arr = np.array(rgba) / 255.0 + rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:]) + image = (rgb * 255).astype(np.uint8) + + images_v0.append(image) + + base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 12 + processed_file = os.path.join(output_folder, f"{base_count:06d}_process_input.mp4") + imageio.mimwrite(processed_file, images_v0, fps=10) + return processed_file def sample_sv3d( image, @@ -212,26 +282,32 @@ def 
sample_sv3d( polar_rad: Optional[Union[float, List[float]]] = None, azim_rad: Optional[List[float]] = None, verbose: Optional[bool] = False, + sv3d_model=None, ): """ Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. """ - if version == "sv3d_u": - model_config = "scripts/sampling/configs/sv3d_u.yaml" - elif version == "sv3d_p": - model_config = "scripts/sampling/configs/sv3d_p.yaml" + if sv3d_model is None: + if version == "sv3d_u": + model_config = "scripts/sampling/configs/sv3d_u.yaml" + elif version == "sv3d_p": + model_config = "scripts/sampling/configs/sv3d_p.yaml" + else: + raise ValueError(f"Version {version} does not exist.") + + model, filter = load_model( + model_config, + device, + num_frames, + num_steps, + verbose, + ) else: - raise ValueError(f"Version {version} does not exist.") - - model, filter = load_model( - model_config, - device, - num_frames, - num_steps, - verbose, - ) + model = sv3d_model + + load_module_gpu(model) H, W = image.shape[2:] F = 8 @@ -286,23 +362,30 @@ def denoiser(input, sigma, c): ) samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) + unload_module_gpu(model.model) + unload_module_gpu(model.denoiser) model.en_and_decode_n_samples_a_time = decoding_t samples_x = model.decode_first_stage(samples_z) samples_x[-1:] = value_dict["cond_frames_without_noise"] samples = torch.clamp(samples_x, min=-1.0, max=1.0) - return samples + unload_module_gpu(model) + return samples -def decode_latents(model, samples_z, timesteps): +def decode_latents(model, samples_z, img_matrix, frame_indices, view_indices, timesteps): load_module_gpu(model.first_stage_model) - if isinstance(model.first_stage_model.decoder, VideoDecoder): - samples_x = model.decode_first_stage(samples_z, timesteps=timesteps) - else: - samples_x = model.decode_first_stage(samples_z) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + for t in frame_indices: + for v in view_indices: + if t != 0 and v != 0: + if isinstance(model.first_stage_model.decoder, VideoDecoder): + samples_x = model.decode_first_stage(samples_z[t, v][None], timesteps=timesteps) + else: + samples_x = model.decode_first_stage(samples_z[t, v][None]) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + img_matrix[t][v] = samples * 2 - 1 unload_module_gpu(model.first_stage_model) - return samples + return img_matrix def init_embedder_options_no_st(keys, init_dict, prompt=None, negative_prompt=None): @@ -604,6 +687,7 @@ def run_img2vid( azim_rad=np.linspace(0, 360, 21 + 1)[1:], cond_motion=None, cond_view=None, + decoding_t=None, ): options = version_dict["options"] H = version_dict["H"] @@ -670,12 +754,53 @@ def run_img2vid( force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None), force_cond_zero_embeddings=options.get("force_cond_zero_embeddings", None), return_latents=False, - decoding_t=options.get("decoding_T", T), + decoding_t=decoding_t, ) return samples +def prepare_inputs(frame_indices, img_matrix, v0, view_indices, model, version_dict, seed, polars, azims): + load_module_gpu(model.conditioner) + + forward_frame_indices = frame_indices.copy() + t0 = forward_frame_indices[0] + image = img_matrix[t0][v0] + cond_motion = torch.cat([img_matrix[t][v0] for t in forward_frame_indices], 0) + cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0) + forward_inputs = prepare_sampling( + version_dict, + model, + image, + 
seed, + polars, + azims, + cond_motion, + cond_view, + ) + + # backward sampling + backward_frame_indices = frame_indices[ + ::-1 + ].copy() + t0 = backward_frame_indices[0] + image = img_matrix[t0][v0] + cond_motion = torch.cat([img_matrix[t][v0] for t in backward_frame_indices], 0) + cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0) + backward_inputs = prepare_sampling( + version_dict, + model, + image, + seed, + polars, + azims, + cond_motion, + cond_view, + ) + + unload_module_gpu(model.conditioner) + return forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices + def do_sample( model, sampler, @@ -761,13 +886,11 @@ def denoiser(input, sigma, c): return model.denoiser( model.model, input, sigma, c, **additional_model_inputs ) - load_module_gpu(model.model) load_module_gpu(model.denoiser) samples_z = sampler(denoiser, randn, cond=c, uc=uc) unload_module_gpu(model.model) unload_module_gpu(model.denoiser) - load_module_gpu(model.first_stage_model) if isinstance(model.first_stage_model.decoder, VideoDecoder): samples_x = model.decode_first_stage( @@ -777,17 +900,15 @@ def denoiser(input, sigma, c): samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) unload_module_gpu(model.first_stage_model) - if filter is not None: samples = filter(samples) if return_latents: return samples, samples_z - return samples -def do_sample_per_step( +def prepare_sampling_( model, sampler, value_dict, @@ -797,8 +918,6 @@ def do_sample_per_step( batch2model_input: List = None, T=None, additional_batch_uc_fields=None, - step=None, - noisy_latents=None, ): force_uc_zero_embeddings = default(force_uc_zero_embeddings, []) batch2model_input = default(batch2model_input, []) @@ -812,8 +931,6 @@ def do_sample_per_step( num_samples = [num_samples, T] else: num_samples = [num_samples] - - load_module_gpu(model.conditioner) batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, @@ -827,8 +944,6 @@ def do_sample_per_step( force_uc_zero_embeddings=force_uc_zero_embeddings, force_cond_zero_embeddings=force_cond_zero_embeddings, ) - unload_module_gpu(model.conditioner) - for k in c: if not k == "crossattn": c[k], uc[k] = map( @@ -859,7 +974,14 @@ def do_sample_per_step( ) else: additional_model_inputs[k] = batch[k] + return c, uc, additional_model_inputs + +def do_sample_per_step(model, sampler, noisy_latents, c, uc, step, additional_model_inputs): + precision_scope = autocast + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): noisy_latents_scaled, s_in, sigmas, num_sigmas, _, _ = ( sampler.prepare_sampling_loop( noisy_latents.clone(), c, uc, sampler.num_steps @@ -893,13 +1015,10 @@ def denoiser(input, sigma, c): uc, gamma, ) - unload_module_gpu(model.model) - unload_module_gpu(model.denoiser) - return samples_z -def run_img2vid_per_step( +def prepare_sampling( version_dict, model, image, @@ -908,8 +1027,6 @@ def run_img2vid_per_step( azim_rad=np.linspace(0, 360, 21 + 1)[1:], cond_motion=None, cond_view=None, - step=None, - noisy_latents=None, ): options = version_dict["options"] H = version_dict["H"] @@ -962,7 +1079,7 @@ def run_img2vid_per_step( sampler, num_rows, num_cols = init_sampling_no_st(options=options) num_samples = num_rows * num_cols - samples = do_sample_per_step( + c, uc, additional_model_inputs = prepare_sampling_( model, sampler, value_dict, @@ -971,11 +1088,9 @@ def run_img2vid_per_step( 
force_cond_zero_embeddings=options.get("force_cond_zero_embeddings", None), batch2model_input=["num_video_frames", "image_only_indicator"], T=T, - step=step, - noisy_latents=noisy_latents, ) - return samples + return c, uc, additional_model_inputs, sampler def get_unique_embedder_keys_from_conditioner(conditioner): diff --git a/scripts/sampling/simple_video_sample_4d.py b/scripts/sampling/simple_video_sample_4d.py index c970e74f..704d7a29 100755 --- a/scripts/sampling/simple_video_sample_4d.py +++ b/scripts/sampling/simple_video_sample_4d.py @@ -10,15 +10,19 @@ import torch from fire import Fire +from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder from scripts.demo.sv4d_helpers import ( decode_latents, load_model, initial_model_load, read_video, run_img2vid, - run_img2vid_per_step, + prepare_sampling, + prepare_inputs, + do_sample_per_step, sample_sv3d, save_video, + preprocess_video, ) @@ -32,17 +36,18 @@ def sample( motion_bucket_id: int = 127, cond_aug: float = 1e-5, seed: int = 23, - decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. + encoding_t: int = 8, # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary. + decoding_t: int = 4, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. device: str = "cuda", elevations_deg: Optional[Union[float, List[float]]] = 10.0, azimuths_deg: Optional[List[float]] = None, - image_frame_ratio: Optional[float] = None, + image_frame_ratio: Optional[float] = 0.917, verbose: Optional[bool] = False, remove_bg: bool = False, ): """ Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each - image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. + image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`. """ # Set model config T = 5 # number of frames per sample @@ -89,15 +94,16 @@ def sample( # Read input video frames i.e. 
images at view 0 print(f"Reading {input_path}") - images_v0 = read_video( + processed_input_path = preprocess_video( input_path, + remove_bg=remove_bg, n_frames=n_frames, W=W, H=H, - remove_bg=remove_bg, + output_folder=output_folder, image_frame_ratio=image_frame_ratio, - device=device, ) + images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device) # Get camera viewpoints if isinstance(elevations_deg, float) or isinstance(elevations_deg, int): @@ -139,7 +145,7 @@ def sample( for t in range(n_frames): img_matrix[t][0] = images_v0[t] - base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 11 + base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 12 save_video( os.path.join(output_folder, f"{base_count:06d}_t000.mp4"), img_matrix[0], @@ -158,6 +164,10 @@ def sample( verbose, ) model = initial_model_load(model) + for emb in model.conditioner.embedders: + if isinstance(emb, VideoPredictionEmbedderWithEncoder): + emb.en_and_decode_n_samples_a_time = encoding_t + model.en_and_decode_n_samples_a_time = decoding_t # Interleaved sampling for anchor frames t0, v0 = 0, 0 @@ -171,7 +181,7 @@ def sample( azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() azims = (azims - azimuths_rad[v0]) % (torch.pi * 2) samples = run_img2vid( - version_dict, model, image, seed, polars, azims, cond_motion, cond_view + version_dict, model, image, seed, polars, azims, cond_motion, cond_view, decoding_t ) samples = samples.view(T, V, 3, H, W) for i, t in enumerate(frame_indices): @@ -185,40 +195,48 @@ def sample( frame_indices = t0 + np.arange(T) print(f"Sampling dense frames {frame_indices}") latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda") + + polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() + azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() + azims = (azims - azimuths_rad[v0]) % (torch.pi * 2) + + # alternate between forward and backward conditioning + forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices = prepare_inputs( + frame_indices, + img_matrix, + v0, + view_indices, + model, + version_dict, + seed, + polars, + azims + ) + for step in tqdm(range(num_steps)): - frame_indices = frame_indices[ - ::-1 - ].copy() # alternate between forward and backward conditioning - t0 = frame_indices[0] - image = img_matrix[t0][v0] - cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0) - cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0) - polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() - azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten() - azims = (azims - azimuths_rad[v0]) % (torch.pi * 2) + if step % 2 == 1: + c, uc, additional_model_inputs, sampler = forward_inputs + frame_indices = forward_frame_indices + else: + c, uc, additional_model_inputs, sampler = backward_inputs + frame_indices = backward_frame_indices noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1) - samples = run_img2vid_per_step( - version_dict, + + samples = do_sample_per_step( model, - image, - seed, - polars, - azims, - cond_motion, - cond_view, - step, + sampler, noisy_latents, + c, + uc, + step, + additional_model_inputs, ) samples = samples.view(T, V, C, H // F, W // F) for i, t in enumerate(frame_indices): for j, v in enumerate(view_indices): latent_matrix[t, v] = samples[i, j] - for t in frame_indices: - for v in view_indices: - if t != 0 and v != 0: - img = decode_latents(model, 
latent_matrix[t, v][None], T) - img_matrix[t][v] = img * 2 - 1 + img_matrix = decode_latents(model, latent_matrix, img_matrix, frame_indices, view_indices, T) # Save output videos for v in view_indices:
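For reference, below is a minimal sketch (not part of this diff) of how the two-step pipeline added in `scripts/demo/gradio_app_sv4d.py` could be driven without the Gradio UI. It assumes the functions defined in this PR (`preprocess_video`, `sample_anchor`, `sample_all`), a populated or downloadable `checkpoints/` directory, and an available CUDA device; importing `gradio_app_sv4d` loads both the SV4D and SV3D checkpoints at import time.

```python
# Sketch only: drive the SV4D two-step sampling from Python instead of the Gradio UI.
from scripts.demo.sv4d_helpers import preprocess_video
from scripts.demo.gradio_app_sv4d import sample_anchor, sample_all

# Optional preprocessing: background removal and center-cropping to 576x576.
processed = preprocess_video("assets/sv4d_videos/test_video1.mp4", remove_bg=True)

# Step 1: SV3D reference views + interleaved sampling of the 5x8 anchor frames.
sv3d_file, anchor_vis_file, anchor_file = sample_anchor(
    input_path=processed,
    seed=23,
    encoding_t=8,  # reduce (e.g. to 1) on low-VRAM GPUs
    decoding_t=4,  # reduce (e.g. to 1) on low-VRAM GPUs
    num_steps=20,
)

# Step 2: dense sampling of the remaining frames -> full 21-frame novel-view videos.
vid_file, _ = sample_all(
    input_path=processed,
    sv3d_path=sv3d_file,
    anchor_path=anchor_file,
    seed=23,
    num_steps=20,
)
print("Final multi-view grid video:", vid_file)
```

As in the Gradio callbacks, both steps write their outputs next to the processed input video (`output_folder = os.path.dirname(input_path)`), so the anchor and final videos end up in the same folder as the preprocessed clip.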