diff --git a/.gitignore b/.gitignore
index 5506c38d..8623d1a9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,5 @@
 /dist
 /outputs
 /build
-/src
\ No newline at end of file
+/src
+/.vscode
\ No newline at end of file
diff --git a/README.md b/README.md
old mode 100644
new mode 100755
index 71c0b868..d9b3714c
--- a/README.md
+++ b/README.md
@@ -23,7 +23,8 @@ To run **SV4D** on a single input video of 21 frames:
 - `num_steps` : default is 20, can increase to 50 for better quality but longer sampling time.
 - `sv3d_version` : To specify the SV3D model to generate reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p.
 - `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0`
- - **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos (with noisy background), try segmenting the foreground object using [Cliipdrop](https://clipdrop.co/) before running SV4D.
+ - **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher-quality outputs on real-world input videos with noisy backgrounds, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) before running SV4D.
+ - **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--decoding_t=1` (number of frames decoded at a time) or a lower video resolution such as `--img_size=512`.
 
 ![tile](assets/sv4d.gif)
 
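The new **Low VRAM environment** bullet corresponds to parameters of the `sample()` entry point in `scripts/sampling/simple_video_sample_4d.py`. A minimal sketch of a low-VRAM invocation, assuming the CLI flags shown above map directly onto keyword arguments of `sample()` (as with the Fire-style entry points elsewhere in the repo) and using a placeholder input path:

```python
# Hedged sketch of a low-VRAM run; flag names follow the README bullet above,
# the input path is a placeholder, and defaults may differ between releases.
from scripts.sampling.simple_video_sample_4d import sample

sample(
    input_path="assets/test_video1.mp4",  # placeholder 21-frame input clip
    decoding_t=1,    # decode one frame at a time to cap VAE memory use
    img_size=512,    # lower the working resolution from the 576 default
    remove_bg=True,  # optional rembg-based background removal
)
```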
diff --git a/scripts/demo/sv4d_helpers.py b/scripts/demo/sv4d_helpers.py
old mode 100644
new mode 100755
index b533b5e5..d67231df
--- a/scripts/demo/sv4d_helpers.py
+++ b/scripts/demo/sv4d_helpers.py
@@ -36,6 +36,20 @@ from sgm.util import default, instantiate_from_config
 
 
+def load_module_gpu(model):
+    model.cuda()
+
+
+def unload_module_gpu(model):
+    model.cpu()
+    torch.cuda.empty_cache()
+
+
+def initial_model_load(model):
+    model.model.half()
+    return model
+
+
 def get_resizing_factor(
     desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
 ) -> float:
@@ -60,75 +74,11 @@ def get_resizing_factor(
     return factor
 
 
-def load_img_for_prediction_no_st(
-    image_path: str,
-    mask_path: str,
-    W: int,
-    H: int,
-    crop_h: int,
-    crop_w: int,
-    device="cuda",
-) -> torch.Tensor:
-    image = Image.open(image_path)
-    if image is None:
-        return None
-    image = np.array(image).astype(np.float32) / 255
-    h, w = image.shape[:2]
-    rotated = 0
-
-    mask = None
-    if mask_path is not None:
-        mask = Image.open(mask_path)
-        mask = np.array(mask).astype(np.float32) / 255
-        mask = np.any(mask.reshape(h, w, -1) > 0, axis=2, keepdims=True).astype(
-            np.float32
-        )
-    elif image.shape[-1] == 4:
-        mask = image[:, :, 3:]
-
-    if mask is not None:
-        image = image[:, :, :3] * mask + (1 - mask)
-    # if "DAVIS" in image_path:
-    #     y, x, _ = np.where(mask > 0)
-    #     x_mean, y_mean = np.mean(x), np.mean(y)
-    # else:
-    #     x_mean, y_mean = w//2, h//2
-    # h_new = int(max(crop_h, crop_w) * 1.33)
-    # x_min = max(int(x_mean - h_new//2), 0)
-    # y_min = max(int(y_mean - h_new//2), 0)
-    # image_cropped = image[y_min : y_min + h_new, x_min : x_min + h_new]
-    # h_crop, w_crop = image_cropped.shape[:2]
-    # h_new = max(h_crop, w_crop)
-    # top = max((h_new - h_crop) // 2, 0)
-    # left = max((h_new - w_crop) // 2, 0)
-    # image_padded = np.ones((h_new, h_new, 3)).astype(np.float32)
-    # image_padded[top : top + h_crop, left : left + w_crop, :] = image_cropped
-    # image = image_padded
-    # h, w = image.shape[:2]
-
-    image = image.transpose(2, 0, 1)
-    image = torch.from_numpy(image).to(dtype=torch.float32)
-    image = image.unsqueeze(0)
-
-    rfs = get_resizing_factor((H, W), (h, w))
-    resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
-    top = (resize_size[0] - H) // 2
-    left = (resize_size[1] - W) // 2
-
-    image = torch.nn.functional.interpolate(
-        image, resize_size, mode="area", antialias=False
-    )
-    image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
-    return image.to(device) * 2.0 - 1.0, rotated
-
-
 def read_gif(input_path, n_frames):
     frames = []
     video = Image.open(input_path)
-    if video.n_frames < n_frames:
-        return frames
     for img in ImageSequence.Iterator(video):
-        frames.append(img.convert("RGB"))
+        frames.append(img.convert("RGBA"))
         if len(frames) == n_frames:
             break
     return frames
@@ -206,16 +156,17 @@ def read_video(
     print(f"Loading {len(all_img_paths)} video frames...")
     images = [Image.open(img_path) for img_path in all_img_paths]
+    if len(images) < n_frames:
+        images = (images + images[::-1])[:n_frames]
+
     if len(images) != n_frames:
-        raise ValueError("Input video contains fewer than {n_frames} frames.")
+        raise ValueError(f"Input video contains fewer than {n_frames} frames.")
 
     # Remove background and crop video frames
     images_v0 = []
-    for image in images:
+    for t, image in enumerate(images):
         if remove_bg:
-            if image.mode == "RGBA":
-                pass
-            else:
+            if image.mode != "RGBA":
                 image.thumbnail([W, H], Image.Resampling.LANCZOS)
                 image = remove(image.convert("RGBA"), alpha_matting=True)
             image_arr = np.array(image)
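When the input clip has fewer than `n_frames` frames, the patched `read_video` now pads it by "ping-ponging" the clip before the length check. A minimal sketch of just that padding step (names mirror the hunk above; a clip shorter than roughly half of `n_frames` still fails the subsequent check and raises the now correctly formatted `ValueError`):

```python
def pad_frames(frames: list, n_frames: int) -> list:
    """Ping-pong padding as in the patched read_video: append the reversed
    clip, then truncate to n_frames. List items stand in for PIL frames."""
    if len(frames) < n_frames:
        frames = (frames + frames[::-1])[:n_frames]
    return frames

# A 12-frame clip becomes 21 frames: 0..11 followed by 11, 10, ..., 3.
assert pad_frames(list(range(12)), 21) == list(range(12)) + [11, 10, 9, 8, 7, 6, 5, 4, 3]
```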
@@ -225,11 +176,12 @@ def read_video(
             )
             x, y, w, h = cv2.boundingRect(mask)
             max_size = max(w, h)
-            side_len = (
-                int(max_size / image_frame_ratio)
-                if image_frame_ratio is not None
-                else in_w
-            )
+            if t == 0:
+                side_len = (
+                    int(max_size / image_frame_ratio)
+                    if image_frame_ratio is not None
+                    else in_w
+                )
             padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
             center = side_len // 2
             padded_image[
@@ -239,7 +191,9 @@ def read_video(
             rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS)
             rgba_arr = np.array(rgba) / 255.0
             rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
-            images = Image.fromarray((rgb * 255).astype(np.uint8))
+            image = Image.fromarray((rgb * 255).astype(np.uint8))
+        else:
+            image = image.convert("RGB").resize((W, H), Image.LANCZOS)
         image = ToTensor()(image).unsqueeze(0).to(device)
         images_v0.append(image * 2.0 - 1.0)
     return images_v0
@@ -341,11 +295,13 @@ def denoiser(input, sigma, c):
 
 
 def decode_latents(model, samples_z, timesteps):
+    load_module_gpu(model.first_stage_model)
     if isinstance(model.first_stage_model.decoder, VideoDecoder):
         samples_x = model.decode_first_stage(samples_z, timesteps=timesteps)
     else:
         samples_x = model.decode_first_stage(samples_z)
     samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+    unload_module_gpu(model.first_stage_model)
     return samples
@@ -751,6 +707,7 @@ def do_sample(
         else:
             num_samples = [num_samples]
 
+        load_module_gpu(model.conditioner)
         batch, batch_uc = get_batch(
             get_unique_embedder_keys_from_conditioner(model.conditioner),
             value_dict,
@@ -758,13 +715,13 @@ def do_sample(
             T=T,
             additional_batch_uc_fields=additional_batch_uc_fields,
         )
-
         c, uc = model.conditioner.get_unconditional_conditioning(
             batch,
             batch_uc=batch_uc,
             force_uc_zero_embeddings=force_uc_zero_embeddings,
             force_cond_zero_embeddings=force_cond_zero_embeddings,
         )
+        unload_module_gpu(model.conditioner)
 
         for k in c:
             if not k == "crossattn":
@@ -805,8 +762,13 @@ def denoiser(input, sigma, c):
                 model.model, input, sigma, c, **additional_model_inputs
             )
 
+        load_module_gpu(model.model)
+        load_module_gpu(model.denoiser)
         samples_z = sampler(denoiser, randn, cond=c, uc=uc)
+        unload_module_gpu(model.model)
+        unload_module_gpu(model.denoiser)
 
+        load_module_gpu(model.first_stage_model)
         if isinstance(model.first_stage_model.decoder, VideoDecoder):
             samples_x = model.decode_first_stage(
                 samples_z, timesteps=default(decoding_t, T)
@@ -814,6 +776,7 @@ def denoiser(input, sigma, c):
         else:
             samples_x = model.decode_first_stage(samples_z)
         samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+        unload_module_gpu(model.first_stage_model)
 
         if filter is not None:
             samples = filter(samples)
@@ -850,6 +813,7 @@ def do_sample_per_step(
         else:
             num_samples = [num_samples]
 
+        load_module_gpu(model.conditioner)
         batch, batch_uc = get_batch(
             get_unique_embedder_keys_from_conditioner(model.conditioner),
             value_dict,
@@ -857,13 +821,13 @@ def do_sample_per_step(
             T=T,
             additional_batch_uc_fields=additional_batch_uc_fields,
         )
-
         c, uc = model.conditioner.get_unconditional_conditioning(
             batch,
             batch_uc=batch_uc,
             force_uc_zero_embeddings=force_uc_zero_embeddings,
             force_cond_zero_embeddings=force_cond_zero_embeddings,
        )
+        unload_module_gpu(model.conditioner)
 
         for k in c:
             if not k == "crossattn":
@@ -917,6 +881,9 @@ def denoiser(input, sigma, c):
             if sampler.s_tmin <= sigmas[step] <= sampler.s_tmax
             else 0.0
         )
+
+        load_module_gpu(model.model)
+        load_module_gpu(model.denoiser)
         samples_z = sampler.sampler_step(
             s_in * sigmas[step],
             s_in * sigmas[step + 1],
@@ -926,6 +893,8 @@ def denoiser(input, sigma, c):
             uc,
             gamma,
         )
+        unload_module_gpu(model.model)
+        unload_module_gpu(model.denoiser)
 
         return samples_z
 
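The `load_module_gpu`/`unload_module_gpu` pairs threaded through `decode_latents`, `do_sample`, and `do_sample_per_step` implement sequential offloading: only the stage currently running (conditioner, denoiser/diffusion network, or VAE) resides on the GPU, and the CUDA cache is cleared after each stage. A hedged sketch of the same pairing expressed as a context manager (not part of the patch), which keeps each load/unload pair balanced by construction:

```python
from contextlib import contextmanager

import torch


@contextmanager
def on_gpu(module: torch.nn.Module):
    """Move a submodule to the GPU for one stage, then return it to the CPU
    and release cached memory, mirroring load_module_gpu/unload_module_gpu."""
    module.cuda()
    try:
        yield module
    finally:
        module.cpu()
        torch.cuda.empty_cache()


# For example, the conditioner stage in do_sample could read:
# with on_gpu(model.conditioner):
#     c, uc = model.conditioner.get_unconditional_conditioning(batch, batch_uc=batch_uc)
```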
diff --git a/scripts/sampling/configs/sv4d.yaml b/scripts/sampling/configs/sv4d.yaml
old mode 100644
new mode 100755
index b908b758..56e6ddb3
--- a/scripts/sampling/configs/sv4d.yaml
+++ b/scripts/sampling/configs/sv4d.yaml
@@ -93,12 +93,6 @@ model:
             sigma_sampler_config:
               target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
 
-      # - input_key: cond_aug
-      #   is_trainable: False
-      #   target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
-      #   params:
-      #     outdim: 256
-
       - input_key: polar_rad
         is_trainable: False
         target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
diff --git a/scripts/sampling/simple_video_sample_4d.py b/scripts/sampling/simple_video_sample_4d.py
old mode 100644
new mode 100755
index 51c809ea..c970e74f
--- a/scripts/sampling/simple_video_sample_4d.py
+++ b/scripts/sampling/simple_video_sample_4d.py
@@ -13,6 +13,7 @@ from scripts.demo.sv4d_helpers import (
     decode_latents,
     load_model,
+    initial_model_load,
     read_video,
     run_img2vid,
     run_img2vid_per_step,
@@ -26,6 +27,7 @@ def sample(
     output_folder: Optional[str] = "outputs/sv4d",
     num_steps: Optional[int] = 20,
     sv3d_version: str = "sv3d_u",  # sv3d_u or sv3d_p
+    img_size: int = 576,  # image resolution
     fps_id: int = 6,
     motion_bucket_id: int = 127,
     cond_aug: float = 1e-5,
@@ -47,7 +49,7 @@ def sample(
     V = 8  # number of views per sample
     F = 8  # vae factor to downsize image->latent
     C = 4
-    H, W = 576, 576
+    H, W = img_size, img_size
     n_frames = 21  # number of input and output video frames
     n_views = V + 1  # number of output video views (1 input view + 8 novel views)
     n_views_sv3d = 21
@@ -64,7 +66,7 @@ def sample(
         "f": F,
         "options": {
             "discretization": 1,
-            "cfg": 2.5,
+            "cfg": 3.0,
             "sigma_min": 0.002,
             "sigma_max": 700.0,
             "rho": 7.0,
@@ -137,7 +139,7 @@ def sample(
     for t in range(n_frames):
         img_matrix[t][0] = images_v0[t]
 
-    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 10
+    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 11
     save_video(
         os.path.join(output_folder, f"{base_count:06d}_t000.mp4"),
         img_matrix[0],
@@ -155,6 +157,7 @@ def sample(
         num_steps,
         verbose,
     )
+    model = initial_model_load(model)
 
     # Interleaved sampling for anchor frames
     t0, v0 = 0, 0
diff --git a/sgm/modules/spacetime_attention.py b/sgm/modules/spacetime_attention.py
index c604c1b8..1ca44bb2 100644
--- a/sgm/modules/spacetime_attention.py
+++ b/sgm/modules/spacetime_attention.py
@@ -593,4 +593,4 @@ def forward(
         if not self.use_linear:
             x = self.proj_out(x)
         out = x + x_in
-        return out
+        return out
\ No newline at end of file
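Taken together, the sampling-script changes load the model once, cast the diffusion network to fp16 via `initial_model_load`, and rely on the reworked `decode_latents` to keep the VAE off the GPU except while decoding. A hedged sketch of how these pieces compose; the wrapper function and its arguments are illustrative, not part of the patch:

```python
from scripts.demo.sv4d_helpers import decode_latents, initial_model_load


def decode_with_offload(model, samples_z, decoding_t: int):
    """Illustrative wrapper: halve the diffusion-network weights once, then
    decode latents; decode_latents handles VAE GPU on/off-loading internally."""
    model = initial_model_load(model)  # model.model.half(), as added in sv4d_helpers.py
    # timesteps is forwarded to decode_first_stage for VideoDecoder-based VAEs
    return decode_latents(model, samples_z, timesteps=decoding_t)
```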