Add conditional model (#2)
ewfuentes authored Apr 3, 2024
1 parent b9be0dc commit 23b0151
Showing 1 changed file with 76 additions and 27 deletions.
103 changes: 76 additions & 27 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -279,6 +279,7 @@ def __init__(
dim_mults = (1, 2, 4, 8),
channels = 3,
self_condition = False,
masked_observations = False,
resnet_block_groups = 8,
learned_variance = False,
learned_sinusoidal_cond = False,
@@ -296,7 +297,10 @@ def __init__(

self.channels = channels
self.self_condition = self_condition
input_channels = channels * (2 if self_condition else 1)
self.masked_observations = masked_observations
# If we accept masked observations, append `channels` observation channels
# plus a 1-channel mask after the (possibly self-conditioned) image channels
obs_channels = (channels + 1) if self.masked_observations else 0
input_channels = channels * (2 if self_condition else 1) + obs_channels

init_dim = default(init_dim, dim)
self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
@@ -385,13 +389,19 @@ def __init__(
def downsample_factor(self):
return 2 ** (len(self.downs) - 1)

def forward(self, x, time, x_self_cond = None):
def forward(self, x, time, x_self_cond = None, x_obs = None, x_obs_mask = None):
assert all([divisible_by(d, self.downsample_factor) for d in x.shape[-2:]]), f'your input dimensions {x.shape[-2:]} need to be divisible by {self.downsample_factor}, given the unet'

if self.self_condition:
    x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
    x = torch.cat((x_self_cond, x), dim = 1)

if self.masked_observations:
    batch, _, height, width = x.shape
    # x may already include the self-conditioning channels at this point, so
    # size the default (empty) observations from self.channels, not x.shape
    x_obs = default(x_obs, lambda: torch.zeros((batch, self.channels, height, width), device = x.device))
    x_obs_mask = default(x_obs_mask, lambda: torch.zeros((batch, 1, height, width), device = x.device))
    x = torch.cat((x, x_obs, x_obs_mask), dim = 1)

x = self.init_conv(x)
r = x.clone()
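With `masked_observations` enabled, the Unet consumes the noisy image plus `channels` observation channels and a 1-channel binary mask. A minimal sketch of a forward call, assuming this file's `Unet` (shapes and hyperparameters are illustrative, not from this commit):

import torch
from denoising_diffusion_pytorch.denoising_diffusion_pytorch import Unet

model = Unet(dim = 64, channels = 3, masked_observations = True)

x = torch.randn(4, 3, 64, 64)                          # noisy images
t = torch.randint(0, 1000, (4,))                       # diffusion timesteps
x_obs = torch.randn(4, 3, 64, 64)                      # observed pixel values
x_obs_mask = (torch.rand(4, 1, 64, 64) > 0.5).float()  # 1 where a pixel is observed

out = model(x, t, x_obs = x_obs, x_obs_mask = x_obs_mask)  # same shape as x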

@@ -496,6 +506,7 @@ def __init__(

self.channels = self.model.channels
self.self_condition = self.model.self_condition
self.masked_observations = self.model.masked_observations

# make image size a tuple of (height, width)
if isinstance(image_size, int):
@@ -630,8 +641,8 @@ def q_posterior(self, x_start, x_t, t):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped

def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
model_output = self.model(x, t, x_self_cond)
def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False, x_obs = None, x_obs_mask = None):
model_output = self.model(x, t, x_self_cond, x_obs, x_obs_mask)
maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

if self.objective == 'pred_noise':
@@ -655,8 +666,8 @@ def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):

return ModelPrediction(pred_noise, x_start)

def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
preds = self.model_predictions(x, t, x_self_cond)
def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True, x_obs = None, x_obs_mask = None):
preds = self.model_predictions(x, t, x_self_cond, x_obs=x_obs, x_obs_mask=x_obs_mask)
x_start = preds.pred_x_start

if clip_denoised:
@@ -666,16 +677,16 @@ def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
return model_mean, posterior_variance, posterior_log_variance, x_start

@torch.inference_mode()
def p_sample(self, x, t: int, x_self_cond = None):
def p_sample(self, x, t: int, x_self_cond = None, x_obs = None, x_obs_mask = None):
b, *_, device = *x.shape, self.device
batched_times = torch.full((b,), t, device = device, dtype = torch.long)
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True, x_obs=x_obs, x_obs_mask=x_obs_mask)
noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
return pred_img, x_start

@torch.inference_mode()
def p_sample_loop(self, shape, return_all_timesteps = False):
def p_sample_loop(self, shape, return_all_timesteps = False, x_obs = None, x_obs_mask = None):
batch, device = shape[0], self.device

img = torch.randn(shape, device = device)
@@ -685,7 +696,7 @@ def p_sample_loop(self, shape, return_all_timesteps = False):

for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
self_cond = x_start if self.self_condition else None
img, x_start = self.p_sample(img, t, self_cond)
img, x_start = self.p_sample(img, t, self_cond, x_obs=x_obs, x_obs_mask=x_obs_mask)
imgs.append(img)

ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
@@ -694,7 +705,7 @@ def p_sample_loop(self, shape, return_all_timesteps = False):
return ret

@torch.inference_mode()
def ddim_sample(self, shape, return_all_timesteps = False):
def ddim_sample(self, shape, return_all_timesteps = False, x_obs = None, x_obs_mask = None):
batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
@@ -709,7 +720,7 @@ def ddim_sample(self, shape, return_all_timesteps = False):
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
self_cond = x_start if self.self_condition else None
pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)
pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True, x_obs=x_obs, x_obs_mask=x_obs_mask)

if time_next < 0:
img = x_start
@@ -736,13 +747,16 @@ def ddim_sample(self, shape, return_all_timesteps = False):
return ret

@torch.inference_mode()
def sample(self, batch_size = 16, return_all_timesteps = False):
def sample(self, batch_size = 16, return_all_timesteps = False, x_obs = None, x_obs_mask = None):
image_size, channels = self.image_size, self.channels
sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
return sample_fn((batch_size, channels, image_size[0], image_size[1]), return_all_timesteps = return_all_timesteps)
return sample_fn((batch_size, channels, image_size[0], image_size[1]), return_all_timesteps = return_all_timesteps, x_obs=x_obs, x_obs_mask=x_obs_mask)
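Downstream, the observations thread all the way from `sample` through `p_sample_loop` / `ddim_sample` into the model. A hedged usage sketch of conditional sampling; `diffusion` is assumed to be a trained `GaussianDiffusion` wrapping a masked-observation `Unet` on the same device, and `reference` is a hypothetical stand-in for known image content:

import torch

B, C, H, W = 8, 3, 64, 64
reference = torch.rand(B, C, H, W)              # stand-in for known content

x_obs = torch.zeros(B, C, H, W)
x_obs_mask = torch.zeros(B, 1, H, W)
x_obs[..., : W // 2] = reference[..., : W // 2]  # reveal the left half
x_obs_mask[..., : W // 2] = 1.                   # mark those pixels observed

samples = diffusion.sample(batch_size = B, x_obs = x_obs, x_obs_mask = x_obs_mask)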

@torch.inference_mode()
def interpolate(self, x1, x2, t = None, lam = 0.5):
if self.masked_observations:
    # not yet supported: the p_sample calls below would need x_obs and x_obs_mask threaded through
    raise NotImplementedError()
b, *_, device = *x1.shape, x1.device
t = default(t, self.num_timesteps - 1)

@@ -770,7 +784,7 @@ def q_sample(self, x_start, t, noise = None):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)

def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
def p_losses(self, x_start, t, noise = None, offset_noise_strength = None, x_obs = None, x_obs_mask = None):
b, c, h, w = x_start.shape

noise = default(noise, lambda: torch.randn_like(x_start))
@@ -794,12 +808,12 @@ def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
x_self_cond = None
if self.self_condition and random() < 0.5:
with torch.no_grad():
x_self_cond = self.model_predictions(x, t).pred_x_start
x_self_cond = self.model_predictions(x, t, x_obs=x_obs, x_obs_mask=x_obs_mask).pred_x_start
x_self_cond.detach_()

# predict and take gradient step

model_out = self.model(x, t, x_self_cond)
model_out = self.model(x, t, x_self_cond, x_obs, x_obs_mask)

if self.objective == 'pred_noise':
target = noise
@@ -817,16 +831,24 @@ def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
loss = loss * extract(self.loss_weight, t, loss.shape)
return loss.mean()

def forward(self, img, *args, **kwargs):
def forward(self, img, *args, x_obs = None, x_obs_mask = None, **kwargs):
b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
assert h == img_size[0] and w == img_size[1], f'height and width of image must be {img_size}'
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

img = self.normalize(img)
return self.p_losses(img, t, *args, **kwargs)
return self.p_losses(img, t, *args, x_obs=x_obs, x_obs_mask=x_obs_mask, **kwargs)

# dataset classes

def generate_masks(mask_generator):
    # wrap a user-supplied mask generator, checking that its output exposes
    # the `img`, `obs` and `mask` attributes the training loop expects
    def _inner(input_tensor):
        out = mask_generator(input_tensor)
        for attr in ('img', 'obs', 'mask'):
            assert hasattr(out, attr), f'mask generator output is missing `{attr}`'
        return out
    return _inner
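The `mask_generator` is expected to map a `(C, H, W)` image tensor to an object exposing `img`, `obs`, and `mask` attributes. A hedged sketch of one such generator, a hypothetical random-box mask; any callable honoring this contract should work:

from collections import namedtuple
import torch

MaskedExample = namedtuple('MaskedExample', ['img', 'obs', 'mask'])

def random_box_mask_generator(img):
    _, h, w = img.shape                      # the transform pipeline yields (C, H, W)
    mask = torch.zeros(1, h, w)
    top = torch.randint(0, h // 2, (1,)).item()
    left = torch.randint(0, w // 2, (1,)).item()
    mask[:, top : top + h // 2, left : left + w // 2] = 1.
    obs = img * mask                         # observed pixels; zero where unobserved
    return MaskedExample(img = img, obs = obs, mask = mask)

Note that the training loop below calls `.to(device)` on the whole batch and reads `data.img` / `data.obs` / `data.mask`, so a plain namedtuple would additionally need a custom collate function or batch wrapper that supports `.to`; that plumbing is left to the caller.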

class Dataset(Dataset):
def __init__(
self,
@@ -835,7 +857,8 @@ def __init__(
exts = ['jpg', 'jpeg', 'png', 'tiff'],
interpolation = 'bilinear',
augment_horizontal_flip = False,
convert_image_to = None
convert_image_to = None,
mask_generator = None,
):
super().__init__()
self.folder = folder
@@ -850,7 +873,8 @@ def __init__(
T.Resize(image_size, interpolation=interpolation),
T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
T.CenterCrop(image_size),
T.ToTensor()
T.ToTensor(),
T.Lambda(generate_masks(mask_generator)) if mask_generator else nn.Identity(),
])

def __len__(self):
@@ -889,7 +913,8 @@ def __init__(
inception_block_idx = 2048,
max_grad_norm = 1.,
num_fid_samples = 50000,
save_best_and_latest_only = False
save_best_and_latest_only = False,
mask_generator = None,
):
super().__init__()

@@ -934,6 +959,7 @@ def __init__(
interpolation = interpolation,
augment_horizontal_flip = augment_horizontal_flip,
convert_image_to = convert_image_to,
mask_generator = mask_generator,
)

assert len(self.ds) >= 100, 'you should have at least 100 images in your folder. at least 10k images recommended'
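End to end, the new pieces compose as in the following sketch; the hyperparameters and path are illustrative, and `random_box_mask_generator` is the hypothetical generator from the earlier example:

model = Unet(dim = 64, channels = 3, masked_observations = True)

diffusion = GaussianDiffusion(model, image_size = 64, timesteps = 1000)

trainer = Trainer(
    diffusion,
    'path/to/images',
    train_batch_size = 16,
    mask_generator = random_box_mask_generator,
)
trainer.train()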
@@ -1045,7 +1071,11 @@ def train(self):
data = next(self.dl).to(device)

with self.accelerator.autocast():
loss = self.model(data)
if self.model.masked_observations:
    img, x_obs, x_obs_mask = data.img, data.obs, data.mask
    loss = self.model(img, x_obs = x_obs, x_obs_mask = x_obs_mask)
else:
    loss = self.model(data)
loss = loss / self.gradient_accumulate_every
total_loss += loss.item()

@@ -1068,17 +1098,36 @@ def train(self):
if self.step != 0 and divisible_by(self.step, self.save_and_sample_every):
self.ema.ema_model.eval()

is_masked_observations = self.ema.ema_model.masked_observations
with torch.inference_mode():
milestone = self.step // self.save_and_sample_every
batches = num_to_groups(self.num_samples, self.batch_size)
all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))
additional_args_list = [{} for _ in batches]  # per-batch sampling kwargs; filled in below when conditioning
if is_masked_observations:
num_masks = int(math.sqrt(self.num_samples))
data = next(self.dl).to(device)

all_images = torch.cat(all_images_list, dim = 0)
imgs = data.img[:num_masks]
obs = data.obs[:num_masks]
masks = data.mask[:num_masks]
split_obs = torch.split(obs.repeat(num_masks, 1, 1, 1), batches)
split_masks = torch.split(masks.repeat(num_masks, 1, 1, 1), batches)

utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = int(math.sqrt(self.num_samples)))
additional_args_list = [{
'x_obs': o,
'x_obs_mask': m,
} for o, m in zip(split_obs, split_masks)]

# whether to calculate fid
all_images_list = list(map(lambda args: self.ema.ema_model.sample(batch_size=args[0], **args[1]), zip(batches, additional_args_list)))

all_images = torch.cat(
([imgs, obs] if is_masked_observations else []) +
all_images_list, dim = 0)

num_rows = int(math.sqrt(self.num_samples))
utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = num_rows)

# whether to calculate fid
if self.calculate_fid:
fid_score = self.fid_scorer.fid_score()
accelerator.print(f'fid_score: {fid_score}')
