From 28e929606b00e7778f507006017c50aa12d88ab4 Mon Sep 17 00:00:00 2001 From: Erick Fuentes Date: Tue, 2 Apr 2024 21:33:50 -0400 Subject: [PATCH] fix check for masked observations during training --- denoising_diffusion_pytorch/denoising_diffusion_pytorch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py index 101a48d8b..4c4d0d53c 100644 --- a/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +++ b/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py @@ -1060,6 +1060,7 @@ def load(self, milestone): def train(self): accelerator = self.accelerator device = accelerator.device + is_masked_observations = self.ema.ema_model.masked_observations with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: @@ -1071,7 +1072,7 @@ def train(self): data = next(self.dl).to(device) with self.accelerator.autocast(): - if self.model.masked_observation: + if is_masked_observations: img, x_obs, x_obs_mask = data loss = self.model(img, x_obs=x_obs, x_obs_mask=x_obs_mask) else: @@ -1098,7 +1099,6 @@ def train(self): if self.step != 0 and divisible_by(self.step, self.save_and_sample_every): self.ema.ema_model.eval() - is_masked_observations = self.ema.ema_model.masked_observations with torch.inference_mode(): milestone = self.step // self.save_and_sample_every batches = num_to_groups(self.num_samples, self.batch_size)