diff --git a/README.md b/README.md
index 65440536..dd9d61d8 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,24 @@
 ![sample1](assets/000.jpg)
 
 ## News
+
+**March 18, 2024**
+- We are releasing [SV3D](https://huggingface.co/stabilityai/sv3d), an image-to-video model for novel multi-view synthesis, for research purposes:
+  - SV3D was trained to generate 21 frames at resolution 576x576, given 1 context frame of the same size, ideally a white-background image with one object.
+  - SV3D_u: This variant generates orbital videos from a single image input, without camera conditioning.
+  - SV3D_p: Extending the capability of SV3D_u, this variant accepts both single images and orbital views, allowing for the creation of 3D video along specified camera paths.
+  - We extend the streamlit demo `scripts/demo/video_sampling.py` and the standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.
+  - Please check our [project page](https://sv3d.github.io), [tech report](https://sv3d.github.io/static/paper.pdf) and [video summary](https://youtu.be/Zqw4-1LcfWg) for more details.
+
+To run SV3D on a single image:
+`python scripts/sampling/simple_video_sample.py --input_path <path/to/image.png> --version sv3d_p`
+
+To run SVD or SV3D on a streamlit server:
+`streamlit run scripts/demo/video_sampling.py`
+
+ ![tile](assets/sv3d.gif)
+
+
 
 **November 30, 2023**
 - Following the launch of SDXL-Turbo, we are releasing [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo).
@@ -24,7 +42,7 @@
   We use the standard image encoder from SD 2.1, but replace the decoder with a temporally-aware `deflickering decoder`.
 - [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt): Same architecture as `SVD` but finetuned for 25 frame generation.
-  - You can run the community-build gradio demo locally by running `python -m scripts.demo.gradio_app`.
+  - You can run the community-build gradio demo locally by running `python -m scripts.demo.gradio_app`.
 - We provide a streamlit demo `scripts/demo/video_sampling.py` and a standalone python script `scripts/sampling/simple_video_sample.py` for inference of both models.
 - Alongside the model, we release a [technical report](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets).
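For reference, the `sample()` entry point that the new README command invokes (extended later in this diff in `scripts/sampling/simple_video_sample.py`) can also be called directly from Python. A minimal sketch, assuming the SV3D_p checkpoint referenced by `scripts/sampling/configs/sv3d_p.yaml` has been downloaded into `checkpoints/` and that this is run from the repo root; the input image is the repo's bundled test asset:

```python
from scripts.sampling.simple_video_sample import sample

# SV3D_p orbit at a constant 10-degree elevation; when `azimuths_deg` is left as
# None, the script falls back to an evenly spaced full 360-degree sweep.
sample(
    input_path="assets/test_image.png",
    version="sv3d_p",
    elevations_deg=10.0,
)
```

Per the version branch added below, the resulting frames are written as an MP4 to `outputs/simple_video_sample/sv3d_p/` by default.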
diff --git a/assets/sv3d.gif b/assets/sv3d.gif new file mode 100644 index 00000000..7503ad75 Binary files /dev/null and b/assets/sv3d.gif differ diff --git a/configs/inference/sv3d_p.yaml b/configs/inference/sv3d_p.yaml new file mode 100644 index 00000000..d3781fe5 --- /dev/null +++ b/configs/inference/sv3d_p.yaml @@ -0,0 +1,118 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.18215 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.Denoiser + params: + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise + + network_config: + target: sgm.modules.diffusionmodules.video_model.VideoUNet + params: + adm_in_channels: 1280 + num_classes: sequential + use_checkpoint: True + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_head_channels: 64 + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + spatial_transformer_attn_type: softmax-xformers + extra_ff_mix_layer: True + use_spatial_context: True + merge_strategy: learned_with_images + video_kernel_size: [3, 1, 1] + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - input_key: cond_frames_without_noise + is_trainable: False + target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder + params: + n_cond_frames: 1 + n_copies: 1 + open_clip_embedding_config: + target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder + params: + freeze: True + + - input_key: cond_frames + is_trainable: False + target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder + params: + disable_encoder_autocast: True + n_cond_frames: 1 + n_copies: 1 + is_ae: True + encoder_config: + target: sgm.models.autoencoder.AutoencoderKLModeOnly + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + - input_key: cond_aug + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - input_key: polars_rad + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 512 + + - input_key: azimuths_rad + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 512 + + first_stage_config: + target: sgm.models.autoencoder.AutoencodingEngine + params: + loss_config: + target: torch.nn.Identity + regularizer_config: + target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer + encoder_config: + target: torch.nn.Identity + decoder_config: + target: sgm.modules.diffusionmodules.model.Decoder + params: + attn_type: vanilla-xformers + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 4, 4 ] + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 \ No newline at end of file diff --git a/configs/inference/sv3d_u.yaml b/configs/inference/sv3d_u.yaml new file mode 100644 index 00000000..5c48a5ff --- /dev/null +++ b/configs/inference/sv3d_u.yaml @@ -0,0 +1,106 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.18215 + disable_first_stage_autocast: True + + denoiser_config: + 
target: sgm.modules.diffusionmodules.denoiser.Denoiser + params: + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise + + network_config: + target: sgm.modules.diffusionmodules.video_model.VideoUNet + params: + adm_in_channels: 256 + num_classes: sequential + use_checkpoint: True + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_head_channels: 64 + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + spatial_transformer_attn_type: softmax-xformers + extra_ff_mix_layer: True + use_spatial_context: True + merge_strategy: learned_with_images + video_kernel_size: [3, 1, 1] + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - input_key: cond_frames_without_noise + is_trainable: False + target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder + params: + n_cond_frames: 1 + n_copies: 1 + open_clip_embedding_config: + target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder + params: + freeze: True + + - input_key: cond_frames + is_trainable: False + target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder + params: + disable_encoder_autocast: True + n_cond_frames: 1 + n_copies: 1 + is_ae: True + encoder_config: + target: sgm.models.autoencoder.AutoencoderKLModeOnly + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + - input_key: cond_aug + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + first_stage_config: + target: sgm.models.autoencoder.AutoencodingEngine + params: + loss_config: + target: torch.nn.Identity + regularizer_config: + target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer + encoder_config: + target: torch.nn.Identity + decoder_config: + target: sgm.modules.diffusionmodules.model.Decoder + params: + attn_type: vanilla-xformers + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 4, 4 ] + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 \ No newline at end of file diff --git a/model_licenses/LICENSE-SV3D b/model_licenses/LICENSE-SV3D new file mode 100644 index 00000000..2a9ddf37 --- /dev/null +++ b/model_licenses/LICENSE-SV3D @@ -0,0 +1,41 @@ +STABILITY AI NON-COMMERCIAL COMMUNITY LICENSE AGREEMENT +Dated: March 18, 2024 + +"Agreement" means this Stable Non-Commercial Research Community License Agreement. + +“AUP” means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may be updated from time to time. + +"Derivative Work(s)” means (a) any derivative work of the Software Products as recognized by U.S. copyright laws, (b) any modifications to a Model, and (c) any other model created which is based on or derived from the Model or the Model’s output. For clarity, Derivative Works do not include the output of any Model. + +“Documentation” means any specifications, manuals, documentation, and other written information provided by Stability AI related to the Software. 
+ +"Licensee" or "you" means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. + +“Model(s)" means, collectively, Stability AI’s proprietary models and algorithms, including machine-learning models, trained model weights and other elements of the foregoing, made available under this Agreement. + +“Non-Commercial Uses” means exercising any of the rights granted herein for the purpose of research or non-commercial purposes. Non-Commercial Uses does not include any production use of the Software Products or any Derivative Works. + +"Stability AI" or "we" means Stability AI Ltd and its affiliates. + + +"Software" means Stability AI’s proprietary software made available under this Agreement. + +“Software Products” means the Models, Software and Documentation, individually or in any combination. + + + +1. License Rights and Redistribution. +a. Subject to your compliance with this Agreement, the AUP (which is hereby incorporated herein by reference), and the Documentation, Stability AI grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license under Stability AI’s intellectual property or other rights owned or controlled by Stability AI embodied in the Software Products to use, reproduce, distribute, and create Derivative Works of, the Software Products, in each case for Non-Commercial Uses only. +b. You may not use the Software Products or Derivative Works to enable third parties to use the Software Products or Derivative Works as part of your hosted service or via your APIs, whether you are adding substantial additional functionality thereto or not. Merely distributing the Software Products or Derivative Works for download online without offering any related service (ex. by distributing the Models on HuggingFace) is not a violation of this subsection. If you wish to use the Software Products or any Derivative Works for commercial or production use or you wish to make the Software Products or any Derivative Works available to third parties via your hosted service or your APIs, contact Stability AI at https://stability.ai/contact. +c. If you distribute or make the Software Products, or any Derivative Works thereof, available to a third party, the Software Products, Derivative Works, or any portion thereof, respectively, will remain subject to this Agreement and you must (i) provide a copy of this Agreement to such third party, and (ii) retain the following attribution notice within a "Notice" text file distributed as a part of such copies: "This Stability AI Model is licensed under the Stability AI Non-Commercial Research Community License, Copyright (c) Stability AI Ltd. All Rights Reserved.” If you create a Derivative Work of a Software Product, you may add your own attribution notices to the Notice file included with the Software Product, provided that you clearly indicate which attributions apply to the Software Product and you must state in the NOTICE file that you changed the Software Product and how it was modified. +2. Disclaimer of Warranty. 
UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS, DERIVATIVE WORKS OR ANY OUTPUT OR RESULTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS, DERIVATIVE WORKS AND ANY OUTPUT AND RESULTS. +3. Limitation of Liability. IN NO EVENT WILL STABILITY AI OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF STABILITY AI OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. +4. Intellectual Property. +a. No trademark licenses are granted under this Agreement, and in connection with the Software Products or Derivative Works, neither Stability AI nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Software Products or Derivative Works. +b. Subject to Stability AI’s ownership of the Software Products and Derivative Works made by or for Stability AI, with respect to any Derivative Works that are made by you, as between you and Stability AI, you are and will be the owner of such Derivative Works +c. If you institute litigation or other proceedings against Stability AI (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, Derivative Works or associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out of or related to your use or distribution of the Software Products or Derivative Works in violation of this Agreement. +5. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Software Products and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Stability AI may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of any Software Products or Derivative Works. Sections 2-4 shall survive the termination of this Agreement. + +6. Governing Law. This Agreement will be governed by and construed in accordance with the laws of the United States and the State of California without regard to choice of law +principles. 
+ diff --git a/model_licenses/LICENSE-SDV b/model_licenses/LICENSE-SVD similarity index 100% rename from model_licenses/LICENSE-SDV rename to model_licenses/LICENSE-SVD diff --git a/requirements/pt2.txt b/requirements/pt2.txt index 26bb71a6..824473ab 100644 --- a/requirements/pt2.txt +++ b/requirements/pt2.txt @@ -19,6 +19,7 @@ pillow>=9.5.0 pudb>=2022.1.3 pytorch-lightning==2.0.1 pyyaml>=6.0.1 +rembg scipy>=1.10.1 streamlit>=0.73.1 tensorboardx==2.6 diff --git a/scripts/demo/gradio_app.py b/scripts/demo/gradio_app.py index ab7c9d30..ed6d4877 100644 --- a/scripts/demo/gradio_app.py +++ b/scripts/demo/gradio_app.py @@ -23,9 +23,11 @@ from torchvision.transforms import ToTensor from scripts.sampling.simple_video_sample import ( - get_batch, get_unique_embedder_keys_from_conditioner, load_model) -from scripts.util.detection.nsfw_and_watermark_dectection import \ - DeepFloydDataFiltering + get_batch, + get_unique_embedder_keys_from_conditioner, + load_model, +) +from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering from sgm.inference.helpers import embed_watermark from sgm.util import default, instantiate_from_config diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 6c5760e2..e79fc193 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -5,6 +5,7 @@ from typing import Dict, List, Optional, Tuple, Union import cv2 +import imageio import numpy as np import streamlit as st import torch @@ -15,25 +16,30 @@ from omegaconf import ListConfig, OmegaConf from PIL import Image from safetensors.torch import load_file as load_safetensors +from scripts.demo.discretization import ( + Img2ImgDiscretizationWrapper, + Txt2NoisyDiscretizationWrapper, +) +from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering +from sgm.inference.helpers import embed_watermark +from sgm.modules.diffusionmodules.guiders import ( + LinearPredictionGuider, + TrianglePredictionGuider, + VanillaCFG, +) +from sgm.modules.diffusionmodules.sampling import ( + DPMPP2MSampler, + DPMPP2SAncestralSampler, + EulerAncestralSampler, + EulerEDMSampler, + HeunEDMSampler, + LinearMultistepSampler, +) +from sgm.util import append_dims, default, instantiate_from_config from torch import autocast from torchvision import transforms from torchvision.utils import make_grid, save_image -from scripts.demo.discretization import (Img2ImgDiscretizationWrapper, - Txt2NoisyDiscretizationWrapper) -from scripts.util.detection.nsfw_and_watermark_dectection import \ - DeepFloydDataFiltering -from sgm.inference.helpers import embed_watermark -from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider, - VanillaCFG) -from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler, - DPMPP2SAncestralSampler, - EulerAncestralSampler, - EulerEDMSampler, - HeunEDMSampler, - LinearMultistepSampler) -from sgm.util import append_dims, default, instantiate_from_config - @st.cache_resource() def init_st(version_dict, load_ckpt=True, load_filter=True): @@ -222,6 +228,7 @@ def get_guider(options, key): "VanillaCFG", "IdentityGuider", "LinearPredictionGuider", + "TrianglePredictionGuider", ], options.get("guider", 0), ) @@ -252,7 +259,7 @@ def get_guider(options, key): value=options.get("cfg", 1.5), min_value=1.0, ) - min_scale = st.number_input( + min_scale = st.sidebar.number_input( f"min guidance scale", value=options.get("min_cfg", 1.0), min_value=1.0, @@ -268,6 +275,29 @@ def get_guider(options, key): 
**additional_guider_kwargs, }, } + elif guider == "TrianglePredictionGuider": + max_scale = st.number_input( + f"max-cfg-scale #{key}", + value=options.get("cfg", 2.5), + min_value=1.0, + max_value=10.0, + ) + min_scale = st.sidebar.number_input( + f"min guidance scale", + value=options.get("min_cfg", 1.0), + min_value=1.0, + max_value=10.0, + ) + + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider", + "params": { + "max_scale": max_scale, + "min_scale": min_scale, + "num_frames": options["num_frames"], + **additional_guider_kwargs, + }, + } else: raise NotImplementedError return guider_config @@ -288,8 +318,8 @@ def init_sampling( f"num cols #{key}", value=num_cols, min_value=1, max_value=10 ) - steps = st.sidebar.number_input( - f"steps #{key}", value=options.get("num_steps", 40), min_value=1, max_value=1000 + steps = st.number_input( + f"steps #{key}", value=options.get("num_steps", 50), min_value=1, max_value=1000 ) sampler = st.sidebar.selectbox( f"Sampler #{key}", @@ -337,13 +367,13 @@ def get_discretization(discretization, options, key=1): "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", } elif discretization == "EDMDiscretization": - sigma_min = st.number_input( + sigma_min = st.sidebar.number_input( f"sigma_min #{key}", value=options.get("sigma_min", 0.03) ) # 0.0292 - sigma_max = st.number_input( + sigma_max = st.sidebar.number_input( f"sigma_max #{key}", value=options.get("sigma_max", 14.61) ) # 14.6146 - rho = st.number_input(f"rho #{key}", value=options.get("rho", 3.0)) + rho = st.sidebar.number_input(f"rho #{key}", value=options.get("rho", 3.0)) discretization_config = { "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", "params": { @@ -542,7 +572,12 @@ def do_sample( assert T is not None if isinstance( - sampler.guider, (VanillaCFG, LinearPredictionGuider) + sampler.guider, + ( + VanillaCFG, + LinearPredictionGuider, + TrianglePredictionGuider, + ), ): additional_model_inputs[k] = torch.zeros( num_samples[0] * 2, num_samples[1] @@ -678,6 +713,12 @@ def get_batch( batch[key] = repeat( value_dict["cond_frames_without_noise"], "1 ... 
-> b ...", b=N[0] ) + elif key == "polars_rad": + batch[key] = torch.tensor(value_dict["polars_rad"]).to(device).repeat(N[0]) + elif key == "azimuths_rad": + batch[key] = ( + torch.tensor(value_dict["azimuths_rad"]).to(device).repeat(N[0]) + ) else: batch[key] = value_dict[key] @@ -827,8 +868,13 @@ def load_img_for_prediction( st.image(image) w, h = image.size - image = np.array(image).transpose(2, 0, 1) - image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0 + image = np.array(image).astype(np.float32) / 255 + if image.shape[-1] == 4: + rgb, alpha = image[:, :, :3], image[:, :, 3:] + image = rgb * alpha + (1 - alpha) + + image = image.transpose(2, 0, 1) + image = torch.from_numpy(image).to(dtype=torch.float32) image = image.unsqueeze(0) rfs = get_resizing_factor((H, W), (h, w)) @@ -860,28 +906,16 @@ def save_video_as_grid_and_mp4( save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4) video_path = os.path.join(save_path, f"{base_count:06d}.mp4") - - writer = cv2.VideoWriter( - video_path, - cv2.VideoWriter_fourcc(*"MP4V"), - fps, - (vid.shape[-1], vid.shape[-2]), - ) - vid = ( (rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8) ) - for frame in vid: - frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) - writer.write(frame) - - writer.release() + imageio.mimwrite(video_path, vid, fps=fps) video_path_h264 = video_path[:-4] + "_h264.mp4" - os.system(f"ffmpeg -i {video_path} -c:v libx264 {video_path_h264}") - + os.system(f"ffmpeg -i '{video_path}' -c:v libx264 '{video_path_h264}'") with open(video_path_h264, "rb") as f: video_bytes = f.read() + os.remove(video_path_h264) st.video(video_bytes) base_count += 1 diff --git a/scripts/demo/sv3d_helpers.py b/scripts/demo/sv3d_helpers.py new file mode 100644 index 00000000..a0cebd19 --- /dev/null +++ b/scripts/demo/sv3d_helpers.py @@ -0,0 +1,104 @@ +import os + +import matplotlib.pyplot as plt +import numpy as np + + +def generate_dynamic_cycle_xy_values( + length=21, + init_elev=0, + num_components=84, + frequency_range=(1, 5), + amplitude_range=(0.5, 10), + step_range=(0, 2), +): + # Y values generation + y_sequence = np.ones(length) * init_elev + for _ in range(num_components): + # Choose a frequency that will complete whole cycles in the sequence + frequency = np.random.randint(*frequency_range) * (2 * np.pi / length) + amplitude = np.random.uniform(*amplitude_range) + phase_shift = np.random.choice([0, np.pi]) # np.random.uniform(0, 2 * np.pi) + angles = ( + np.linspace(0, frequency * length, length, endpoint=False) + phase_shift + ) + y_sequence += np.sin(angles) * amplitude + # X values generation + # Generate length - 1 steps since the last step is back to start + steps = np.random.uniform(*step_range, length - 1) + total_step_sum = np.sum(steps) + # Calculate the scale factor to scale total steps to just under 360 + scale_factor = ( + 360 - ((360 / length) * np.random.uniform(*step_range)) + ) / total_step_sum + # Apply the scale factor and generate the sequence of X values + x_values = np.cumsum(steps * scale_factor) + # Ensure the sequence starts at 0 and add the final step to complete the loop + x_values = np.insert(x_values, 0, 0) + return x_values, y_sequence + + +def smooth_data(data, window_size): + # Extend data at both ends by wrapping around to create a continuous loop + pad_size = window_size + padded_data = np.concatenate((data[-pad_size:], data, data[:pad_size])) + + # Apply smoothing + kernel = np.ones(window_size) / window_size + smoothed_data = np.convolve(padded_data, 
kernel, mode="same") + + # Extract the smoothed data corresponding to the original sequence + # Adjust the indices to account for the larger padding + start_index = pad_size + end_index = -pad_size if pad_size != 0 else None + smoothed_original_data = smoothed_data[start_index:end_index] + return smoothed_original_data + + +# Function to generate and process the data +def gen_dynamic_loop(length=21, elev_deg=0): + while True: + # Generate the combined X and Y values using the new function + azim_values, elev_values = generate_dynamic_cycle_xy_values( + length=84, init_elev=elev_deg + ) + # Smooth the Y values directly + smoothed_elev_values = smooth_data(elev_values, 5) + max_magnitude = np.max(np.abs(smoothed_elev_values)) + if max_magnitude < 90: + break + subsample = 84 // length + azim_rad = np.deg2rad(azim_values[::subsample]) + elev_rad = np.deg2rad(smoothed_elev_values[::subsample]) + # Make cond frame the last one + return np.roll(azim_rad, -1), np.roll(elev_rad, -1) + + +def plot_3D(azim, polar, save_path, dynamic=True): + os.makedirs(os.path.dirname(save_path), exist_ok=True) + elev = np.deg2rad(90) - polar + fig = plt.figure(figsize=(5, 5)) + ax = fig.add_subplot(projection="3d") + cm = plt.get_cmap("Greys") + col_line = [cm(i) for i in np.linspace(0.3, 1, len(azim) + 1)] + cm = plt.get_cmap("cool") + col = [cm(float(i) / (len(azim))) for i in np.arange(len(azim))] + xs = np.cos(elev) * np.cos(azim) + ys = np.cos(elev) * np.sin(azim) + zs = np.sin(elev) + ax.scatter(xs[0], ys[0], zs[0], s=100, color=col[0]) + xs_d, ys_d, zs_d = (xs[1:] - xs[:-1]), (ys[1:] - ys[:-1]), (zs[1:] - zs[:-1]) + for i in range(len(xs) - 1): + if dynamic: + ax.quiver( + xs[i], ys[i], zs[i], xs_d[i], ys_d[i], zs_d[i], lw=2, color=col_line[i] + ) + else: + ax.plot(xs[i : i + 2], ys[i : i + 2], zs[i : i + 2], lw=2, c=col_line[i]) + ax.scatter(xs[i + 1], ys[i + 1], zs[i + 1], s=100, color=col[i + 1]) + ax.scatter(xs[:1], ys[:1], zs[:1], s=120, facecolors="none", edgecolors="k") + ax.scatter(xs[-1:], ys[-1:], zs[-1:], s=120, facecolors="none", edgecolors="k") + ax.view_init(elev=30, azim=-20, roll=0) + plt.savefig(save_path, bbox_inches="tight") + plt.clf() + plt.close() diff --git a/scripts/demo/video_sampling.py b/scripts/demo/video_sampling.py index 95789020..1f4fcfc4 100644 --- a/scripts/demo/video_sampling.py +++ b/scripts/demo/video_sampling.py @@ -1,8 +1,10 @@ import os +import sys +sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../"))) from pytorch_lightning import seed_everything - from scripts.demo.streamlit_helpers import * +from scripts.demo.sv3d_helpers import * SAVE_PATH = "outputs/demo/vid/" @@ -87,11 +89,51 @@ "decoding_t": 14, }, }, + "sv3d_u": { + "T": 21, + "H": 576, + "W": 576, + "C": 4, + "f": 8, + "config": "configs/inference/sv3d_u.yaml", + "ckpt": "checkpoints/sv3d_u.safetensors", + "options": { + "discretization": 1, + "cfg": 2.5, + "sigma_min": 0.002, + "sigma_max": 700.0, + "rho": 7.0, + "guider": 3, + "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"], + "num_steps": 50, + "decoding_t": 14, + }, + }, + "sv3d_p": { + "T": 21, + "H": 576, + "W": 576, + "C": 4, + "f": 8, + "config": "configs/inference/sv3d_p.yaml", + "ckpt": "checkpoints/sv3d_p.safetensors", + "options": { + "discretization": 1, + "cfg": 2.5, + "sigma_min": 0.002, + "sigma_max": 700.0, + "rho": 7.0, + "guider": 3, + "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"], + "num_steps": 50, + "decoding_t": 14, + }, + }, } if __name__ == "__main__": - 
st.title("Stable Video Diffusion") + st.title("Stable Video Diffusion / SV3D") version = st.selectbox( "Model Version", [k for k in VERSION2SPECS.keys()], @@ -131,17 +173,42 @@ {}, ) + if "fps" not in ukeys: + value_dict["fps"] = 10 + value_dict["image_only_indicator"] = 0 if mode == "img2vid": img = load_img_for_prediction(W, H) - cond_aug = st.number_input( - "Conditioning augmentation:", value=0.02, min_value=0.0 - ) + if "sv3d" in version: + cond_aug = 1e-5 + else: + cond_aug = st.number_input( + "Conditioning augmentation:", value=0.02, min_value=0.0 + ) value_dict["cond_frames_without_noise"] = img value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img) value_dict["cond_aug"] = cond_aug + if "sv3d_p" in version: + elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90) + trajectory = st.selectbox( + "Trajectory", + ["same elevation", "dynamic"], + 0, + ) + if trajectory == "same elevation": + value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T) + value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:] + elif trajectory == "dynamic": + azim_rad, elev_rad = gen_dynamic_loop(length=21, elev_deg=elev_deg) + value_dict["polars_rad"] = np.deg2rad(90) - elev_rad + value_dict["azimuths_rad"] = azim_rad + elif "sv3d_u" in version: + elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90) + value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T) + value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:] + seed = st.sidebar.number_input( "seed", value=23, min_value=0, max_value=int(1e9) ) @@ -151,6 +218,19 @@ os.path.join(SAVE_PATH, version), init_value=True ) + if "sv3d" in version: + plot_save_path = os.path.join(save_path, "plot_3D.png") + plot_3D( + azim=value_dict["azimuths_rad"], + polar=value_dict["polars_rad"], + save_path=plot_save_path, + dynamic=("sv3d_p" in version), + ) + st.image( + plot_save_path, + f"3D camera trajectory", + ) + options["num_frames"] = T sampler, num_rows, num_cols = init_sampling(options=options) diff --git a/scripts/sampling/configs/sv3d_p.yaml b/scripts/sampling/configs/sv3d_p.yaml new file mode 100644 index 00000000..bb3747c7 --- /dev/null +++ b/scripts/sampling/configs/sv3d_p.yaml @@ -0,0 +1,132 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.18215 + disable_first_stage_autocast: True + ckpt_path: checkpoints/sv3d_p_image_decoder.safetensors + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.Denoiser + params: + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise + + network_config: + target: sgm.modules.diffusionmodules.video_model.VideoUNet + params: + adm_in_channels: 1280 + num_classes: sequential + use_checkpoint: True + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_head_channels: 64 + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + spatial_transformer_attn_type: softmax-xformers + extra_ff_mix_layer: True + use_spatial_context: True + merge_strategy: learned_with_images + video_kernel_size: [3, 1, 1] + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - input_key: cond_frames_without_noise + is_trainable: False + target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder + params: + n_cond_frames: 1 + n_copies: 1 + open_clip_embedding_config: + target: 
sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder + params: + freeze: True + + - input_key: cond_frames + is_trainable: False + target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder + params: + disable_encoder_autocast: True + n_cond_frames: 1 + n_copies: 1 + is_ae: True + encoder_config: + target: sgm.models.autoencoder.AutoencoderKLModeOnly + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + - input_key: cond_aug + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - input_key: polars_rad + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 512 + + - input_key: azimuths_rad + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 512 + + first_stage_config: + target: sgm.models.autoencoder.AutoencodingEngine + params: + loss_config: + target: torch.nn.Identity + regularizer_config: + target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer + encoder_config: + target: torch.nn.Identity + decoder_config: + target: sgm.modules.diffusionmodules.model.Decoder + params: + attn_type: vanilla-xformers + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 4, 4 ] + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler + params: + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization + params: + sigma_max: 700.0 + + guider_config: + target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider + params: + max_scale: 2.5 diff --git a/scripts/sampling/configs/sv3d_u.yaml b/scripts/sampling/configs/sv3d_u.yaml new file mode 100644 index 00000000..8a7ce212 --- /dev/null +++ b/scripts/sampling/configs/sv3d_u.yaml @@ -0,0 +1,120 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.18215 + disable_first_stage_autocast: True + ckpt_path: checkpoints/sv3d_u_image_decoder.safetensors + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.Denoiser + params: + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise + + network_config: + target: sgm.modules.diffusionmodules.video_model.VideoUNet + params: + adm_in_channels: 256 + num_classes: sequential + use_checkpoint: True + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_head_channels: 64 + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + spatial_transformer_attn_type: softmax-xformers + extra_ff_mix_layer: True + use_spatial_context: True + merge_strategy: learned_with_images + video_kernel_size: [3, 1, 1] + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: False + input_key: cond_frames_without_noise + target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder + params: + n_cond_frames: 1 + n_copies: 1 + open_clip_embedding_config: + target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder + params: + freeze: True + + - input_key: 
cond_frames + is_trainable: False + target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder + params: + disable_encoder_autocast: True + n_cond_frames: 1 + n_copies: 1 + is_ae: True + encoder_config: + target: sgm.models.autoencoder.AutoencoderKLModeOnly + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + - input_key: cond_aug + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + first_stage_config: + target: sgm.models.autoencoder.AutoencodingEngine + params: + loss_config: + target: torch.nn.Identity + regularizer_config: + target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer + encoder_config: + target: torch.nn.Identity + decoder_config: + target: sgm.modules.diffusionmodules.model.Decoder + params: + attn_type: vanilla-xformers + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 4, 4 ] + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler + params: + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization + params: + sigma_max: 700.0 + + guider_config: + target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider + params: + max_scale: 2.5 diff --git a/scripts/sampling/simple_video_sample.py b/scripts/sampling/simple_video_sample.py index c3f4ad2a..6d34e7cd 100644 --- a/scripts/sampling/simple_video_sample.py +++ b/scripts/sampling/simple_video_sample.py @@ -1,27 +1,29 @@ import math import os +import sys from glob import glob from pathlib import Path -from typing import Optional +from typing import List, Optional +sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../"))) import cv2 +import imageio import numpy as np import torch from einops import rearrange, repeat from fire import Fire from omegaconf import OmegaConf from PIL import Image -from torchvision.transforms import ToTensor - -from scripts.util.detection.nsfw_and_watermark_dectection import \ - DeepFloydDataFiltering +from rembg import remove +from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering from sgm.inference.helpers import embed_watermark from sgm.util import default, instantiate_from_config +from torchvision.transforms import ToTensor def sample( input_path: str = "assets/test_image.png", # Can either be image file or folder with image files - num_frames: Optional[int] = None, + num_frames: Optional[int] = None, # 21 for SV3D num_steps: Optional[int] = None, version: str = "svd", fps_id: int = 6, @@ -31,6 +33,10 @@ def sample( decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. 
device: str = "cuda", output_folder: Optional[str] = None, + elevations_deg: Optional[float | List[float]] = 10.0, # For SV3D + azimuths_deg: Optional[float | List[float]] = None, # For SV3D + image_frame_ratio: Optional[float] = None, + verbose: Optional[bool] = False, ): """ Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each @@ -61,6 +67,24 @@ def sample( output_folder, "outputs/simple_video_sample/svd_xt_image_decoder/" ) model_config = "scripts/sampling/configs/svd_xt_image_decoder.yaml" + elif version == "sv3d_u": + num_frames = 21 + num_steps = default(num_steps, 50) + output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_u/") + model_config = "scripts/sampling/configs/sv3d_u.yaml" + cond_aug = 1e-5 + elif version == "sv3d_p": + num_frames = 21 + num_steps = default(num_steps, 50) + output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_p/") + model_config = "scripts/sampling/configs/sv3d_p.yaml" + cond_aug = 1e-5 + if isinstance(elevations_deg, float) or isinstance(elevations_deg, int): + elevations_deg = [elevations_deg] * num_frames + polars_rad = [np.deg2rad(90 - e) for e in elevations_deg] + if azimuths_deg is None: + azimuths_deg = np.linspace(0, 360, num_frames + 1)[1:] % 360 + azimuths_rad = [np.deg2rad(a) for a in azimuths_deg] else: raise ValueError(f"Version {version} does not exist.") @@ -69,6 +93,7 @@ def sample( device, num_frames, num_steps, + verbose, ) torch.manual_seed(seed) @@ -93,20 +118,56 @@ def sample( raise ValueError for input_img_path in all_img_paths: - with Image.open(input_img_path) as image: + if "sv3d" in version: + image = Image.open(input_img_path) if image.mode == "RGBA": - image = image.convert("RGB") - w, h = image.size - - if h % 64 != 0 or w % 64 != 0: - width, height = map(lambda x: x - x % 64, (w, h)) - image = image.resize((width, height)) - print( - f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!" - ) + pass + else: + # remove bg + image.thumbnail([768, 768], Image.Resampling.LANCZOS) + image = remove(image.convert("RGBA"), alpha_matting=True) + + # resize object in frame + image_arr = np.array(image) + in_w, in_h = image_arr.shape[:2] + ret, mask = cv2.threshold( + np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY + ) + x, y, w, h = cv2.boundingRect(mask) + max_size = max(w, h) + side_len = ( + int(max_size / image_frame_ratio) + if image_frame_ratio is not None + else in_w + ) + padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) + center = side_len // 2 + padded_image[ + center - h // 2 : center - h // 2 + h, + center - w // 2 : center - w // 2 + w, + ] = image_arr[y : y + h, x : x + w] + # resize frame to 576x576 + rgba = Image.fromarray(padded_image).resize((576, 576), Image.LANCZOS) + # white bg + rgba_arr = np.array(rgba) / 255.0 + rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:]) + input_image = Image.fromarray((rgb * 255).astype(np.uint8)) - image = ToTensor()(image) - image = image * 2.0 - 1.0 + else: + with Image.open(input_img_path) as image: + if image.mode == "RGBA": + input_image = image.convert("RGB") + w, h = image.size + + if h % 64 != 0 or w % 64 != 0: + width, height = map(lambda x: x - x % 64, (w, h)) + input_image = input_image.resize((width, height)) + print( + f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!" 
+ ) + + image = ToTensor()(input_image) + image = image * 2.0 - 1.0 image = image.unsqueeze(0).to(device) H, W = image.shape[2:] @@ -114,10 +175,14 @@ def sample( F = 8 C = 4 shape = (num_frames, C, H // F, W // F) - if (H, W) != (576, 1024): + if (H, W) != (576, 1024) and "sv3d" not in version: print( "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`." ) + if (H, W) != (576, 576) and "sv3d" in version: + print( + "WARNING: The conditioning frame you provided is not 576x576. This leads to suboptimal performance as model was only trained on 576x576." + ) if motion_bucket_id > 255: print( "WARNING: High motion bucket! This may lead to suboptimal performance." @@ -130,12 +195,14 @@ def sample( print("WARNING: Large fps value! This may lead to suboptimal performance.") value_dict = {} + value_dict["cond_frames_without_noise"] = image value_dict["motion_bucket_id"] = motion_bucket_id value_dict["fps_id"] = fps_id value_dict["cond_aug"] = cond_aug - value_dict["cond_frames_without_noise"] = image value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image) - value_dict["cond_aug"] = cond_aug + if "sv3d_p" in version: + value_dict["polars_rad"] = polars_rad + value_dict["azimuths_rad"] = azimuths_rad with torch.no_grad(): with torch.autocast(device): @@ -177,16 +244,15 @@ def denoiser(input, sigma, c): samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) model.en_and_decode_n_samples_a_time = decoding_t samples_x = model.decode_first_stage(samples_z) + if "sv3d" in version: + samples_x[-1:] = value_dict["cond_frames_without_noise"] samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) os.makedirs(output_folder, exist_ok=True) base_count = len(glob(os.path.join(output_folder, "*.mp4"))) - video_path = os.path.join(output_folder, f"{base_count:06d}.mp4") - writer = cv2.VideoWriter( - video_path, - cv2.VideoWriter_fourcc(*"MP4V"), - fps_id + 1, - (samples.shape[-1], samples.shape[-2]), + + imageio.imwrite( + os.path.join(output_folder, f"{base_count:06d}.jpg"), input_image ) samples = embed_watermark(samples) @@ -197,10 +263,8 @@ def denoiser(input, sigma, c): .numpy() .astype(np.uint8) ) - for frame in vid: - frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) - writer.write(frame) - writer.release() + video_path = os.path.join(output_folder, f"{base_count:06d}.mp4") + imageio.mimwrite(video_path, vid) def get_unique_embedder_keys_from_conditioner(conditioner): @@ -230,12 +294,10 @@ def get_batch(keys, value_dict, N, T, device): "1 -> b", b=math.prod(N), ) - elif key == "cond_frames": - batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0]) - elif key == "cond_frames_without_noise": - batch[key] = repeat( - value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0] - ) + elif key == "cond_frames" or key == "cond_frames_without_noise": + batch[key] = repeat(value_dict[key], "1 ... 
-> b ...", b=N[0]) + elif key == "polars_rad" or key == "azimuths_rad": + batch[key] = torch.tensor(value_dict[key]).to(device).repeat(N[0]) else: batch[key] = value_dict[key] @@ -253,6 +315,7 @@ def load_model( device: str, num_frames: int, num_steps: int, + verbose: bool = False, ): config = OmegaConf.load(config) if device == "cuda": @@ -260,6 +323,7 @@ def load_model( 0 ].params.open_clip_embedding_config.params.init_device = device + config.model.params.sampler_config.params.verbose = verbose config.model.params.sampler_config.params.num_steps = num_steps config.model.params.sampler_config.params.guider_config.params.num_frames = ( num_frames diff --git a/sgm/modules/diffusionmodules/guiders.py b/sgm/modules/diffusionmodules/guiders.py index e8eca43e..bcaa01c5 100644 --- a/sgm/modules/diffusionmodules/guiders.py +++ b/sgm/modules/diffusionmodules/guiders.py @@ -1,6 +1,6 @@ import logging from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Literal, Optional, Tuple, Union import torch from einops import rearrange, repeat @@ -97,3 +97,35 @@ def prepare_inputs( assert c[k] == uc[k] c_out[k] = c[k] return torch.cat([x] * 2), torch.cat([s] * 2), c_out + + +class TrianglePredictionGuider(LinearPredictionGuider): + def __init__( + self, + max_scale: float, + num_frames: int, + min_scale: float = 1.0, + period: float | List[float] = 1.0, + period_fusing: Literal["mean", "multiply", "max"] = "max", + additional_cond_keys: Optional[Union[List[str], str]] = None, + ): + super().__init__(max_scale, num_frames, min_scale, additional_cond_keys) + values = torch.linspace(0, 1, num_frames) + # Constructs a triangle wave + if isinstance(period, float): + period = [period] + + scales = [] + for p in period: + scales.append(self.triangle_wave(values, p)) + + if period_fusing == "mean": + scale = sum(scales) / len(period) + elif period_fusing == "multiply": + scale = torch.prod(torch.stack(scales), dim=0) + elif period_fusing == "max": + scale = torch.max(torch.stack(scales), dim=0).values + self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0) + + def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor: + return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
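To make the new `TrianglePredictionGuider` schedule concrete, here is a minimal sketch, assuming the repo's `sgm` package is importable (e.g. when run from the repo root) and using the defaults the SV3D configs rely on (`max_scale=2.5`, `num_frames=21`, `period=1.0`); the values follow from the `triangle_wave` definition above.

```python
from sgm.modules.diffusionmodules.guiders import TrianglePredictionGuider

# Per-frame guidance weights for a 21-frame orbit: the scale starts at min_scale,
# rises linearly to max_scale halfway through the sequence, then falls back down.
guider = TrianglePredictionGuider(max_scale=2.5, num_frames=21)

print(guider.scale.shape)  # torch.Size([1, 21])
print(guider.scale)        # ~1.0 at the first and last frames, ~2.5 mid-sequence
```

Compared with `LinearPredictionGuider`, which ramps the scale monotonically across frames, the triangle schedule returns to `min_scale` at the end of the sequence, matching an orbit that loops back to the conditioning view.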