Skip to content

Commit

Permalink
Merge pull request #392 from Stability-AI/chunhan/sv4d
Browse files Browse the repository at this point in the history
update sv4d sampling script and readme
  • Loading branch information
voletiv authored Aug 2, 2024
2 parents 8636655 + ce1576b commit e0596f1
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 90 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
/dist
/outputs
/build
/src
/src
/.vscode
3 changes: 2 additions & 1 deletion README.md
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ To run **SV4D** on a single input video of 21 frames:
- `num_steps` : default is 20, can increase to 50 for better quality but longer sampling time.
- `sv3d_version` : To specify the SV3D model to generate reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p.
- `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0`
- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos (with noisy background), try segmenting the foreground object using [Cliipdrop](https://clipdrop.co/) before running SV4D.
- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Cliipdrop](https://clipdrop.co/) before running SV4D.
- **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--decoding_t=1` (of frames decoded at a time) or lower video resolution like `--img_size=512`.

![tile](assets/sv4d.gif)

Expand Down
125 changes: 47 additions & 78 deletions scripts/demo/sv4d_helpers.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@
from sgm.util import default, instantiate_from_config


def load_module_gpu(model):
model.cuda()


def unload_module_gpu(model):
model.cpu()
torch.cuda.empty_cache()


def initial_model_load(model):
model.model.half()
return model


def get_resizing_factor(
desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
) -> float:
Expand All @@ -60,75 +74,11 @@ def get_resizing_factor(
return factor


def load_img_for_prediction_no_st(
image_path: str,
mask_path: str,
W: int,
H: int,
crop_h: int,
crop_w: int,
device="cuda",
) -> torch.Tensor:
image = Image.open(image_path)
if image is None:
return None
image = np.array(image).astype(np.float32) / 255
h, w = image.shape[:2]
rotated = 0

mask = None
if mask_path is not None:
mask = Image.open(mask_path)
mask = np.array(mask).astype(np.float32) / 255
mask = np.any(mask.reshape(h, w, -1) > 0, axis=2, keepdims=True).astype(
np.float32
)
elif image.shape[-1] == 4:
mask = image[:, :, 3:]

if mask is not None:
image = image[:, :, :3] * mask + (1 - mask)
# if "DAVIS" in image_path:
# y, x, _ = np.where(mask > 0)
# x_mean, y_mean = np.mean(x), np.mean(y)
# else:
# x_mean, y_mean = w//2, h//2
# h_new = int(max(crop_h, crop_w) * 1.33)
# x_min = max(int(x_mean - h_new//2), 0)
# y_min = max(int(y_mean - h_new//2), 0)
# image_cropped = image[y_min : y_min + h_new, x_min : x_min + h_new]
# h_crop, w_crop = image_cropped.shape[:2]
# h_new = max(h_crop, w_crop)
# top = max((h_new - h_crop) // 2, 0)
# left = max((h_new - w_crop) // 2, 0)
# image_padded = np.ones((h_new, h_new, 3)).astype(np.float32)
# image_padded[top : top + h_crop, left : left + w_crop, :] = image_cropped
# image = image_padded
# h, w = image.shape[:2]

image = image.transpose(2, 0, 1)
image = torch.from_numpy(image).to(dtype=torch.float32)
image = image.unsqueeze(0)

rfs = get_resizing_factor((H, W), (h, w))
resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
top = (resize_size[0] - H) // 2
left = (resize_size[1] - W) // 2

image = torch.nn.functional.interpolate(
image, resize_size, mode="area", antialias=False
)
image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
return image.to(device) * 2.0 - 1.0, rotated


def read_gif(input_path, n_frames):
frames = []
video = Image.open(input_path)
if video.n_frames < n_frames:
return frames
for img in ImageSequence.Iterator(video):
frames.append(img.convert("RGB"))
frames.append(img.convert("RGBA"))
if len(frames) == n_frames:
break
return frames
Expand Down Expand Up @@ -206,16 +156,17 @@ def read_video(
print(f"Loading {len(all_img_paths)} video frames...")
images = [Image.open(img_path) for img_path in all_img_paths]

if len(images) < n_frames:
images = (images + images[::-1])[:n_frames]

if len(images) != n_frames:
raise ValueError("Input video contains fewer than {n_frames} frames.")
raise ValueError(f"Input video contains fewer than {n_frames} frames.")

# Remove background and crop video frames
images_v0 = []
for image in images:
for t, image in enumerate(images):
if remove_bg:
if image.mode == "RGBA":
pass
else:
if image.mode != "RGBA":
image.thumbnail([W, H], Image.Resampling.LANCZOS)
image = remove(image.convert("RGBA"), alpha_matting=True)
image_arr = np.array(image)
Expand All @@ -225,11 +176,12 @@ def read_video(
)
x, y, w, h = cv2.boundingRect(mask)
max_size = max(w, h)
side_len = (
int(max_size / image_frame_ratio)
if image_frame_ratio is not None
else in_w
)
if t == 0:
side_len = (
int(max_size / image_frame_ratio)
if image_frame_ratio is not None
else in_w
)
padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
center = side_len // 2
padded_image[
Expand All @@ -239,7 +191,9 @@ def read_video(
rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS)
rgba_arr = np.array(rgba) / 255.0
rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
images = Image.fromarray((rgb * 255).astype(np.uint8))
image = Image.fromarray((rgb * 255).astype(np.uint8))
else:
image = image.convert("RGB").resize((W, H), Image.LANCZOS)
image = ToTensor()(image).unsqueeze(0).to(device)
images_v0.append(image * 2.0 - 1.0)
return images_v0
Expand Down Expand Up @@ -341,11 +295,13 @@ def denoiser(input, sigma, c):


def decode_latents(model, samples_z, timesteps):
load_module_gpu(model.first_stage_model)
if isinstance(model.first_stage_model.decoder, VideoDecoder):
samples_x = model.decode_first_stage(samples_z, timesteps=timesteps)
else:
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
unload_module_gpu(model.first_stage_model)
return samples


Expand Down Expand Up @@ -751,20 +707,21 @@ def do_sample(
else:
num_samples = [num_samples]

load_module_gpu(model.conditioner)
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
T=T,
additional_batch_uc_fields=additional_batch_uc_fields,
)

c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
force_cond_zero_embeddings=force_cond_zero_embeddings,
)
unload_module_gpu(model.conditioner)

for k in c:
if not k == "crossattn":
Expand Down Expand Up @@ -805,15 +762,21 @@ def denoiser(input, sigma, c):
model.model, input, sigma, c, **additional_model_inputs
)

load_module_gpu(model.model)
load_module_gpu(model.denoiser)
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
unload_module_gpu(model.model)
unload_module_gpu(model.denoiser)

load_module_gpu(model.first_stage_model)
if isinstance(model.first_stage_model.decoder, VideoDecoder):
samples_x = model.decode_first_stage(
samples_z, timesteps=default(decoding_t, T)
)
else:
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
unload_module_gpu(model.first_stage_model)

if filter is not None:
samples = filter(samples)
Expand Down Expand Up @@ -850,20 +813,21 @@ def do_sample_per_step(
else:
num_samples = [num_samples]

load_module_gpu(model.conditioner)
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
T=T,
additional_batch_uc_fields=additional_batch_uc_fields,
)

c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
force_cond_zero_embeddings=force_cond_zero_embeddings,
)
unload_module_gpu(model.conditioner)

for k in c:
if not k == "crossattn":
Expand Down Expand Up @@ -917,6 +881,9 @@ def denoiser(input, sigma, c):
if sampler.s_tmin <= sigmas[step] <= sampler.s_tmax
else 0.0
)

load_module_gpu(model.model)
load_module_gpu(model.denoiser)
samples_z = sampler.sampler_step(
s_in * sigmas[step],
s_in * sigmas[step + 1],
Expand All @@ -926,6 +893,8 @@ def denoiser(input, sigma, c):
uc,
gamma,
)
unload_module_gpu(model.model)
unload_module_gpu(model.denoiser)

return samples_z

Expand Down
6 changes: 0 additions & 6 deletions scripts/sampling/configs/sv4d.yaml
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,6 @@ model:
sigma_sampler_config:
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

# - input_key: cond_aug
# is_trainable: False
# target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
# params:
# outdim: 256

- input_key: polar_rad
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
Expand Down
9 changes: 6 additions & 3 deletions scripts/sampling/simple_video_sample_4d.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from scripts.demo.sv4d_helpers import (
decode_latents,
load_model,
initial_model_load,
read_video,
run_img2vid,
run_img2vid_per_step,
Expand All @@ -26,6 +27,7 @@ def sample(
output_folder: Optional[str] = "outputs/sv4d",
num_steps: Optional[int] = 20,
sv3d_version: str = "sv3d_u", # sv3d_u or sv3d_p
img_size: int = 576, # image resolution
fps_id: int = 6,
motion_bucket_id: int = 127,
cond_aug: float = 1e-5,
Expand All @@ -47,7 +49,7 @@ def sample(
V = 8 # number of views per sample
F = 8 # vae factor to downsize image->latent
C = 4
H, W = 576, 576
H, W = img_size, img_size
n_frames = 21 # number of input and output video frames
n_views = V + 1 # number of output video views (1 input view + 8 novel views)
n_views_sv3d = 21
Expand All @@ -64,7 +66,7 @@ def sample(
"f": F,
"options": {
"discretization": 1,
"cfg": 2.5,
"cfg": 3.0,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
Expand Down Expand Up @@ -137,7 +139,7 @@ def sample(
for t in range(n_frames):
img_matrix[t][0] = images_v0[t]

base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 10
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 11
save_video(
os.path.join(output_folder, f"{base_count:06d}_t000.mp4"),
img_matrix[0],
Expand All @@ -155,6 +157,7 @@ def sample(
num_steps,
verbose,
)
model = initial_model_load(model)

# Interleaved sampling for anchor frames
t0, v0 = 0, 0
Expand Down
2 changes: 1 addition & 1 deletion sgm/modules/spacetime_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,4 +593,4 @@ def forward(
if not self.use_linear:
x = self.proj_out(x)
out = x + x_in
return out
return out

0 comments on commit e0596f1

Please sign in to comment.