fix: unet vid with bs > 1
beniz committed Aug 29, 2024
1 parent 8c2f7ba commit 733ca37
Showing 6 changed files with 55 additions and 54 deletions.
2 changes: 1 addition & 1 deletion examples/example_ddpm_vid_mario.json
@@ -270,7 +270,7 @@
"pool_size": 50,
"save_by_iter": false,
"save_epoch_freq": 1,
"save_latest_freq": 5000,
"save_latest_freq": 1000,
"semantic_cls": false,
"semantic_mask": false,
"temporal_criterion": false,
39 changes: 15 additions & 24 deletions models/modules/diffusion_generator.py
@@ -130,10 +130,6 @@ def restoration_ddpm(
ref,
):
phase = "test"
frame = 0 # To adapt a temporal video, convert the 5D tensor to a 4D tensor to align with the standard DDPM process
if len(y_cond.shape) == 5:
frame = y_cond.shape[1]
y_cond, y_t, y_0, mask = rearrange_5dto4d(y_cond, y_t, y_0, mask)

b, *_ = y_cond.shape

@@ -144,8 +140,6 @@

y_t = self.default(y_t, lambda: torch.randn_like(y_cond))
ret_arr = y_t
if frame != 0: # the temporal video use 5D tensor as input
y_t, y_cond, mask = rearrange_4dto5d(frame, y_t, y_cond, mask)

for i in tqdm(
reversed(range(0, self.denoise_fn.model.num_timesteps_test)),
@@ -164,17 +158,12 @@
ref=ref,
guidance_scale=guidance_scale,
)
if frame != 0:
mask = rearrange(mask, "b f c h w -> (b f) c h w")

if mask is not None:
temp_mask = torch.clamp(mask, min=0.0, max=1.0)
y_t = y_0 * (1.0 - temp_mask) + temp_mask * y_t
if i % sample_inter == 0:
ret_arr = torch.cat([ret_arr, y_t], dim=0)
if frame != 0:
y_t, mask = rearrange_4dto5d(frame, y_t, mask)
if frame != 0:
ret_arr = rearrange(ret_arr, "(b f) c h w -> b f c h w", f=frame)

return y_t, ret_arr

@@ -203,9 +192,9 @@ def p_mean_variance(
y_cond=None,
guidance_scale=0.0,
):
frame = 0
sequence_length = 0
if len(y_t.shape) == 5:
frame = y_t.shape[1]
sequence_length = y_t.shape[1]
y_t, y_cond, mask = rearrange_5dto4d(y_t, y_cond, mask)

noise_level = self.extract(
Expand All @@ -215,8 +204,8 @@ def p_mean_variance(
embed_noise_level = self.compute_gammas(noise_level)

input = torch.cat([y_cond, y_t], dim=1)
if frame != 0:
input, mask = rearrange_4dto5d(frame, input, mask)
if sequence_length != 0:
input, y_t, mask = rearrange_4dto5d(sequence_length, input, y_t, mask)

if guidance_scale > 0.0 and phase == "test":
y_0_hat_uncond = predict_start_from_noise(
@@ -268,6 +257,7 @@ def p_sample(
y_cond=None,
guidance_scale=0.0,
):

model_mean, model_log_variance = self.p_mean_variance(
y_t=y_t,
t=t,
@@ -279,12 +269,10 @@
ref=ref,
guidance_scale=guidance_scale,
)
frame = 0
if len(y_t.shape) == 5:
y_t = rearrange(y_t, "b f c h w -> (b f) c h w")

noise = torch.randn_like(y_t) if any(t > 0) else torch.zeros_like(y_t)

out = model_mean + noise * (0.5 * model_log_variance).exp()

return out

## DDIM
@@ -450,9 +438,9 @@ def ddim_p_mean_variance(
return model_mean, posterior_log_variance

def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):
frame = 0
sequence_length = 0
if len(y_0.shape) == 5:
frame = y_0.shape[1]
sequence_length = y_0.shape[1]
y_0, y_cond, mask = rearrange_5dto4d(y_0, y_cond, mask)

b, *_ = y_0.shape
@@ -483,11 +471,13 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):

input = torch.cat([y_cond, y_noisy], dim=1)

if frame != 0:
input, mask = rearrange_4dto5d(frame, input, mask)
if sequence_length != 0:
input, mask, noise = rearrange_4dto5d(sequence_length, input, mask, noise)

noise_hat = self.denoise_fn(
input, embed_sample_gammas, cls=cls, mask=mask, ref=ref
)

# min-SNR loss weight
phase = "train"
ksnr = 5.0
@@ -506,6 +496,7 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):
)
# reshape min_snr_loss_weight to match noise_hat
min_snr_loss_weight = min_snr_loss_weight.view(-1, 1, 1, 1)

return noise, noise_hat, min_snr_loss_weight

def set_new_sampling_method(self, sampling_method):
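For context on the `# min-SNR loss weight` block above, a minimal sketch of the usual min(SNR_t, k) / SNR_t weighting with k = ksnr = 5.0, reshaped the same way as in the diff. The helper name and the `alphas_cumprod` argument are illustrative assumptions, not attributes of this repository.

    import torch

    def min_snr_weight(alphas_cumprod: torch.Tensor, t: torch.Tensor, k: float = 5.0) -> torch.Tensor:
        """Per-sample weight min(SNR_t, k) / SNR_t, shaped to broadcast over (C, H, W)."""
        a_bar = alphas_cumprod[t]            # cumulative alpha_bar for each sampled timestep, shape (B,)
        snr = a_bar / (1.0 - a_bar)          # signal-to-noise ratio of the noised sample
        weight = snr.clamp(max=k) / snr      # 1.0 at noisy steps (SNR <= k), k / SNR at clean steps
        return weight.view(-1, 1, 1, 1)      # same broadcast shape as min_snr_loss_weight above

The resulting weight multiplies both `noise` and `noise_hat` in the loss, as shown in `models/palette_model.py` further down.
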
4 changes: 2 additions & 2 deletions models/modules/diffusion_utils.py
@@ -139,11 +139,11 @@ def extract(a, t, x_shape=(1, 1, 1, 1)):

def rearrange_5dto4d(*tensors):
"""Rearrange a tensor according to a given pattern using einops.rearrange."""
return [rearrange(tensor, "b f c h w -> (b f) c h w") for tensor in tensors]
return [rearrange(tensor, "b f c h w -> b c (f h) w") for tensor in tensors]


def rearrange_4dto5d(frame, *tensors):
"""Rearrange a tensor from 4D to 5D according to a given pattern using einops.rearrange."""
return [
rearrange(tensor, "(b f) c h w -> b f c h w", f=frame) for tensor in tensors
rearrange(tensor, "b c (f h) w -> b f c h w", f=frame) for tensor in tensors
]
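
The pattern change above keeps the batch dimension intact: the old "(b f) c h w" layout folded frames into the batch axis, while "b c (f h) w" stacks frames along the height axis, which is presumably what the bs > 1 fix relies on. A minimal round-trip sketch, with illustrative shapes not taken from the repository:

    import torch
    from einops import rearrange

    x = torch.randn(2, 4, 3, 8, 8)                      # (b, f, c, h, w): 2 clips of 4 frames
    packed = rearrange(x, "b f c h w -> b c (f h) w")   # (2, 3, 32, 8): frames stacked along height
    restored = rearrange(packed, "b c (f h) w -> b f c h w", f=4)
    assert torch.equal(x, restored)                     # lossless round trip, batch dimension preserved
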
53 changes: 30 additions & 23 deletions models/modules/unet_generator_attn/unet_generator_attn_vid.py
@@ -240,36 +240,49 @@ def _forward(self, x, emb):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]

h = in_rest(x)
hd = in_rest(x)

if self.efficient and self.up:
h = in_conv(h)
h = self.h_upd(h)
hd = in_conv(hd)
hd = self.h_upd(hd)
x = self.x_upd(x)
else:
h = self.h_upd(h)
hd = self.h_upd(hd)
x = self.x_upd(x)
h = in_conv(h)

hd = in_conv(hd)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
hd = self.in_layers(x)

bf, c, h, w = hd.shape
b = bf // f
f = bf // b
hd = hd.view(b, f, c, h, w)

emb_out = self.emb_layers(emb).type(hd.dtype)
emb_out = emb_out.unsqueeze(1)
while len(emb_out.shape) < len(hd.shape):
emb_out = emb_out.unsqueeze(-1)
# emb_out = emb_out[..., None]

if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
scale, shift = torch.chunk(emb_out, 2, dim=2)

##TODO: simplify by receiving the embedding in 'packed' shape ?
hd = hd.view(b * f, c, h, w)
hd = out_norm(hd)
hd = hd.view(b, f, c, h, w)
hd *= (1 + scale) + shift
hd = hd.view(b * f, c, h, w)
hd = out_rest(hd)
else:
h = h + emb_out
h = self.out_layers(h)
hd = hd + emb_out
hd = hd.view(b * f, c, h, w)
hd = self.out_layers(hd)

skipw = 1.0
if self.efficient:
skipw = 1.0 / math.sqrt(2)
output = self.skip_connection(x) + h
output = self.skip_connection(x) + hd
bf, c, h, w = output.shape
b = bf // f
f = bf // b
@@ -1355,11 +1368,6 @@ def __init__(
}

def compute_feats(self, input, embed_gammas):
if embed_gammas is None:
# Only for GAN
b = (input.shape[0], self.cond_embed_dim)
embed_gammas = torch.ones(b).to(input.device)

emb = embed_gammas

hs = []
@@ -1378,7 +1386,6 @@ def compute_feats(self, input, embed_gammas):
return outs, feats, emb

def forward(self, input, embed_gammas=None):

h, hs, emb = self.compute_feats(input, embed_gammas=embed_gammas)
for i, module in enumerate(self.output_blocks):
h = torch.cat([h, hs.pop()], dim=2)
@@ -1392,7 +1399,7 @@ def forward(self, input, embed_gammas=None):
if self.freq_space:
outh = self.iwt(outh)
outh = outh.reshape(b, f, -1, h_dim, w_dim)
outh = rearrange(outh, "b f c h w -> (b f) c h w")

return outh

def get_feats(self, input, extract_layer_ids):
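To illustrate the reshaping pattern the `_forward` changes follow (keep activations packed as (b*f, c, h, w) for the 2D layers, unpack to 5D so a per-clip embedding broadcasts over frames, then repack), a minimal sketch with illustrative shapes; it uses the conventional scale-shift modulation and is not a drop-in copy of the code above.

    import torch

    b, f, c, h, w = 2, 4, 16, 8, 8
    hd = torch.randn(b * f, c, h, w)       # packed activations, as the per-frame 2D convolutions expect
    emb_out = torch.randn(b, 2 * c)        # per-clip embedding that yields scale and shift

    hd = hd.view(b, f, c, h, w)            # unpack so one embedding per clip broadcasts over its frames
    emb_out = emb_out.unsqueeze(1)         # (b, 1, 2c)
    while emb_out.dim() < hd.dim():
        emb_out = emb_out.unsqueeze(-1)    # (b, 1, 2c, 1, 1)
    scale, shift = torch.chunk(emb_out, 2, dim=2)

    hd = hd * (1 + scale) + shift          # FiLM-style scale-shift shared across the clip's frames
    hd = hd.view(b * f, c, h, w)           # repack for the next 2D convolution
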
3 changes: 0 additions & 3 deletions models/palette_model.py
@@ -18,7 +18,6 @@
from .base_diffusion_model import BaseDiffusionModel
from .modules.loss import MultiScaleDiffusionLoss
from .modules.unet_generator_attn.unet_attn_utils import revert_sync_batchnorm
from einops import rearrange


class PaletteModel(BaseDiffusionModel):
@@ -467,7 +466,6 @@ def compute_palette_loss(self):
frame = 0
if len(y_0.shape) == 5:
frame = y_0.shape[1]
mask = rearrange(mask, "b f c h w -> (b f) c h w")

if not self.opt.alg_palette_minsnr:
min_snr_loss_weight = 1.0
@@ -478,7 +476,6 @@
min_snr_loss_weight * mask_binary * noise,
min_snr_loss_weight * mask_binary * noise_hat,
)

else:
loss = self.loss_fn(
min_snr_loss_weight * noise, min_snr_loss_weight * noise_hat
8 changes: 7 additions & 1 deletion train.py
@@ -32,14 +32,20 @@
from util.visualizer import Visualizer
from util.lion_pytorch import Lion
from util.script import get_override_options_names
import datetime


def setup(rank, world_size, port):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = port

# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
dist.init_process_group(
"nccl",
rank=rank,
world_size=world_size,
timeout=datetime.timedelta(seconds=5400),
) # modified timeout from default 10 or 30 mins (?) to 1.5h


def optim(opt, params, lr, betas, weight_decay, eps):