diff --git a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
index 1d1e0b76b..101a48d8b 100644
--- a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+++ b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -279,6 +279,7 @@ def __init__(
         dim_mults = (1, 2, 4, 8),
         channels = 3,
         self_condition = False,
+        masked_observations = False,
         resnet_block_groups = 8,
         learned_variance = False,
         learned_sinusoidal_cond = False,
@@ -296,7 +297,10 @@ def __init__(

         self.channels = channels
         self.self_condition = self_condition
-        input_channels = channels * (2 if self_condition else 1)
+        self.masked_observations = masked_observations
+        # if masked observations are enabled, append `channels` observation channels plus 1 mask channel to the unet input
+        input_channels = channels * (2 if self_condition else 1)
+        input_channels = input_channels + ((channels + 1) if self.masked_observations else 0)

         init_dim = default(init_dim, dim)
         self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3)
@@ -385,13 +389,19 @@ def __init__(

     def downsample_factor(self):
         return 2 ** (len(self.downs) - 1)

-    def forward(self, x, time, x_self_cond = None):
+    def forward(self, x, time, x_self_cond = None, x_obs = None, x_obs_mask = None):
         assert all([divisible_by(d, self.downsample_factor) for d in x.shape[-2:]]), f'your input dimensions {x.shape[-2:]} need to be divisible by {self.downsample_factor}, given the unet'

         if self.self_condition:
             x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
             x = torch.cat((x_self_cond, x), dim = 1)

+        if self.masked_observations:
+            batch, _, height, width = x.shape
+            x_obs = default(x_obs, lambda: torch.zeros((batch, self.channels, height, width), device = x.device))
+            x_obs_mask = default(x_obs_mask, lambda: torch.zeros((batch, 1, height, width), device = x.device))
+            x = torch.cat((x, x_obs, x_obs_mask), dim = 1)
+
         x = self.init_conv(x)
         r = x.clone()
@@ -496,6 +506,7 @@ def __init__(
         self.channels = self.model.channels
         self.self_condition = self.model.self_condition
+        self.masked_observations = self.model.masked_observations

         # make image size a tuple of (height, width)

         if isinstance(image_size, int):
@@ -630,8 +641,8 @@ def q_posterior(self, x_start, x_t, t):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

-    def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
-        model_output = self.model(x, t, x_self_cond)
+    def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False, x_obs = None, x_obs_mask = None):
+        model_output = self.model(x, t, x_self_cond, x_obs, x_obs_mask)
         maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

         if self.objective == 'pred_noise':
@@ -655,8 +666,8 @@ def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rede

         return ModelPrediction(pred_noise, x_start)

-    def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
-        preds = self.model_predictions(x, t, x_self_cond)
+    def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True, x_obs = None, x_obs_mask = None):
+        preds = self.model_predictions(x, t, x_self_cond, x_obs = x_obs, x_obs_mask = x_obs_mask)
         x_start = preds.pred_x_start

         if clip_denoised:
@@ -666,16 +677,16 @@ def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
         return model_mean, posterior_variance, posterior_log_variance, x_start

     @torch.inference_mode()
-    def p_sample(self, x, t: int, x_self_cond = None):
+    def p_sample(self, x, t: int, x_self_cond = None, x_obs = None, x_obs_mask = None):
         b, *_, device = *x.shape, self.device
         batched_times = torch.full((b,), t, device = device, dtype = torch.long)
-        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True, x_obs = x_obs, x_obs_mask = x_obs_mask)
         noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
         pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
         return pred_img, x_start

     @torch.inference_mode()
-    def p_sample_loop(self, shape, return_all_timesteps = False):
+    def p_sample_loop(self, shape, return_all_timesteps = False, x_obs = None, x_obs_mask = None):
         batch, device = shape[0], self.device

         img = torch.randn(shape, device = device)
@@ -685,7 +696,7 @@ def p_sample_loop(self, shape, return_all_timesteps = False):

         for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
             self_cond = x_start if self.self_condition else None
-            img, x_start = self.p_sample(img, t, self_cond)
+            img, x_start = self.p_sample(img, t, self_cond, x_obs = x_obs, x_obs_mask = x_obs_mask)
             imgs.append(img)

         ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
@@ -694,7 +705,7 @@ def p_sample_loop(self, shape, return_all_timesteps = False):
         return ret

     @torch.inference_mode()
-    def ddim_sample(self, shape, return_all_timesteps = False):
+    def ddim_sample(self, shape, return_all_timesteps = False, x_obs = None, x_obs_mask = None):
         batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective

         times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
@@ -709,7 +720,7 @@ def ddim_sample(self, shape, return_all_timesteps = False):
         for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
             time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
             self_cond = x_start if self.self_condition else None
-            pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)
+            pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True, x_obs = x_obs, x_obs_mask = x_obs_mask)

             if time_next < 0:
                 img = x_start
@@ -736,13 +747,16 @@ def ddim_sample(self, shape, return_all_timesteps = False):

         return ret

     @torch.inference_mode()
-    def sample(self, batch_size = 16, return_all_timesteps = False):
+    def sample(self, batch_size = 16, return_all_timesteps = False, x_obs = None, x_obs_mask = None):
         image_size, channels = self.image_size, self.channels
         sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
-        return sample_fn((batch_size, channels, image_size[0], image_size[1]), return_all_timesteps = return_all_timesteps)
+        return sample_fn((batch_size, channels, image_size[0], image_size[1]), return_all_timesteps = return_all_timesteps, x_obs = x_obs, x_obs_mask = x_obs_mask)

     @torch.inference_mode()
     def interpolate(self, x1, x2, t = None, lam = 0.5):
+        if self.masked_observations:
+            # the call to p_sample below would also need x_obs and x_obs_mask arguments
+            raise NotImplementedError()
         b, *_, device = *x1.shape, x1.device
         t = default(t, self.num_timesteps - 1)
@@ -770,7 +784,7 @@ def q_sample(self, x_start, t, noise = None):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )

-    def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
+    def p_losses(self, x_start, t, noise = None, offset_noise_strength = None, x_obs = None, x_obs_mask = None):
         b, c, h, w = x_start.shape

         noise = default(noise, lambda: torch.randn_like(x_start))
@@ -794,12 +808,12 @@ def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
         x_self_cond = None
         if self.self_condition and random() < 0.5:
             with torch.no_grad():
-                x_self_cond = self.model_predictions(x, t).pred_x_start
+                x_self_cond = self.model_predictions(x, t, x_obs = x_obs, x_obs_mask = x_obs_mask).pred_x_start
                 x_self_cond.detach_()

         # predict and take gradient step

-        model_out = self.model(x, t, x_self_cond)
+        model_out = self.model(x, t, x_self_cond, x_obs, x_obs_mask)

         if self.objective == 'pred_noise':
             target = noise
@@ -817,16 +831,24 @@ def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
         loss = loss * extract(self.loss_weight, t, loss.shape)
         return loss.mean()

-    def forward(self, img, *args, **kwargs):
+    def forward(self, img, *args, x_obs = None, x_obs_mask = None, **kwargs):
         b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
         assert h == img_size[0] and w == img_size[1], f'height and width of image must be {img_size}'
         t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

         img = self.normalize(img)
-        return self.p_losses(img, t, *args, **kwargs)
+        return self.p_losses(img, t, *args, x_obs = x_obs, x_obs_mask = x_obs_mask, **kwargs)

 # dataset classes

+def generate_masks(mask_generator):
+    def inner(input_tensor):
+        out = mask_generator(input_tensor)
+        for attr in ('img', 'obs', 'mask'):
+            assert hasattr(out, attr), f'mask generator output must expose an `{attr}` attribute'
+        return out
+    return inner
+
 class Dataset(Dataset):
     def __init__(
         self,
@@ -835,7 +857,8 @@ def __init__(
         exts = ['jpg', 'jpeg', 'png', 'tiff'],
         interpolation = 'bilinear',
         augment_horizontal_flip = False,
-        convert_image_to = None
+        convert_image_to = None,
+        mask_generator = None,
     ):
         super().__init__()
         self.folder = folder
@@ -850,7 +873,8 @@ def __init__(
             T.Resize(image_size, interpolation=interpolation),
             T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
             T.CenterCrop(image_size),
-            T.ToTensor()
+            T.ToTensor(),
+            T.Lambda(generate_masks(mask_generator)) if mask_generator else nn.Identity(),
         ])

     def __len__(self):
@@ -889,7 +913,8 @@ def __init__(
         inception_block_idx = 2048,
         max_grad_norm = 1.,
         num_fid_samples = 50000,
-        save_best_and_latest_only = False
+        save_best_and_latest_only = False,
+        mask_generator = None,
     ):
         super().__init__()
@@ -934,6 +959,7 @@ def __init__(
             interpolation = interpolation,
             augment_horizontal_flip = augment_horizontal_flip,
             convert_image_to = convert_image_to,
+            mask_generator = mask_generator,
         )

         assert len(self.ds) >= 100, 'you should have at least 100 images in your folder. at least 10k images recommended'
@@ -1045,7 +1071,11 @@ def train(self):
-                    data = next(self.dl).to(device)
+                    data = next(self.dl)

                     with self.accelerator.autocast():
-                        loss = self.model(data)
+                        if self.model.masked_observations:
+                            img, x_obs, x_obs_mask = (t.to(device) for t in data)
+                            loss = self.model(img, x_obs = x_obs, x_obs_mask = x_obs_mask)
+                        else:
+                            loss = self.model(data.to(device))
                         loss = loss / self.gradient_accumulate_every
                         total_loss += loss.item()
@@ -1068,17 +1098,36 @@ def train(self):

                     if self.step != 0 and divisible_by(self.step, self.save_and_sample_every):
                         self.ema.ema_model.eval()
+                        is_masked_observations = self.ema.ema_model.masked_observations

                         with torch.inference_mode():
                             milestone = self.step // self.save_and_sample_every
                             batches = num_to_groups(self.num_samples, self.batch_size)
-                            all_images_list = list(map(lambda n: self.ema.ema_model.sample(batch_size=n), batches))
+                            additional_args_list = [{} for _ in batches]
+                            if is_masked_observations:
+                                num_masks = int(math.sqrt(self.num_samples))
+                                data = next(self.dl)

-                        all_images = torch.cat(all_images_list, dim = 0)
+                                imgs = data.img[:num_masks].to(device)
+                                obs = data.obs[:num_masks].to(device)
+                                masks = data.mask[:num_masks].to(device)
+                                split_obs = torch.split(obs.repeat(num_masks, 1, 1, 1), batches)
+                                split_masks = torch.split(masks.repeat(num_masks, 1, 1, 1), batches)

-                        utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = int(math.sqrt(self.num_samples)))
+                                additional_args_list = [{
+                                    'x_obs': o,
+                                    'x_obs_mask': m,
+                                } for o, m in zip(split_obs, split_masks)]

-                        # whether to calculate fid
+                            all_images_list = list(map(lambda args: self.ema.ema_model.sample(batch_size = args[0], **args[1]), zip(batches, additional_args_list)))
+
+                            all_images = torch.cat(
+                                ([imgs, obs] if is_masked_observations else []) +
+                                all_images_list, dim = 0)
+                            num_rows = int(math.sqrt(self.num_samples))
+                            utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = num_rows)
+
+                        # whether to calculate fid
                         if self.calculate_fid:
                             fid_score = self.fid_scorer.fid_score()
                             accelerator.print(f'fid_score: {fid_score}')
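A note on the channel bookkeeping: the observation and mask channels are appended once to the U-Net input, after any self-conditioning duplication, so `init_conv` sees `channels * (2 if self_condition else 1) + channels + 1` input channels when masked observations are enabled. A quick sanity check, assuming this diff is applied (the dimensions are illustrative):

```python
import torch
from denoising_diffusion_pytorch import Unet

unet = Unet(dim = 64, channels = 3, masked_observations = True)
# 3 image channels + 3 observation channels + 1 mask channel
assert unet.init_conv.in_channels == 7

x = torch.randn(2, 3, 64, 64)
t = torch.randint(0, 1000, (2,))
out = unet(x, t)  # x_obs / x_obs_mask fall back to the zero defaults
assert out.shape == x.shape
```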
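The `mask_generator` accepted by `Dataset` (and forwarded by `Trainer`) can be any callable that takes the `(channels, height, width)` tensor produced by `T.ToTensor()` and returns an object exposing `img`, `obs`, and `mask` attributes, as checked by `generate_masks`. A minimal sketch, assuming a namedtuple container and a random-box observation policy (both illustrative, not part of this diff); a namedtuple is convenient because default DataLoader collation turns it into a namedtuple of batched tensors, which supports both the `data.img` attribute access and the tuple unpacking used in `Trainer.train`:

```python
from collections import namedtuple

import torch

# illustrative container; any object with these three attributes satisfies generate_masks
MaskedSample = namedtuple('MaskedSample', ['img', 'obs', 'mask'])

def random_box_mask_generator(img):
    # img: (channels, height, width) tensor in [0, 1], as produced by T.ToTensor()
    _, h, w = img.shape
    mask = torch.zeros(1, h, w)
    # reveal one random box covering roughly a quarter of the image
    top = torch.randint(0, h - h // 2 + 1, (1,)).item()
    left = torch.randint(0, w - w // 2 + 1, (1,)).item()
    mask[:, top:top + h // 2, left:left + w // 2] = 1.
    obs = img * mask  # observed pixel values, zeroed elsewhere
    return MaskedSample(img = img, obs = obs, mask = mask)
```

Passing `mask_generator = random_box_mask_generator` to `Trainer` then routes `(img, obs, mask)` batches through the masked-observation training path.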
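For conditioned sampling, the observation tensors are threaded from `sample` through `p_sample_loop` / `ddim_sample` into the U-Net. A usage sketch, assuming this diff is applied (the model size, image size, and the half-gray observed square are illustrative):

```python
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion

model = Unet(dim = 64, channels = 3, masked_observations = True)
diffusion = GaussianDiffusion(model, image_size = 128, timesteps = 1000)

# condition four samples on the same observed square region
x_obs = torch.zeros(4, 3, 128, 128)        # observed pixel values
x_obs_mask = torch.zeros(4, 1, 128, 128)   # 1 where a pixel is observed, 0 elsewhere
x_obs[:, :, 32:96, 32:96] = 0.5
x_obs_mask[:, :, 32:96, 32:96] = 1.

sampled = diffusion.sample(batch_size = 4, x_obs = x_obs, x_obs_mask = x_obs_mask)  # (4, 3, 128, 128)
```

Omitting `x_obs` and `x_obs_mask` falls back to the zero defaults in `Unet.forward`, i.e. unconditional sampling.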