Skip to content

Commit

Permalink
save samples and models to ./results path
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 10, 2020
1 parent ff451f6 commit d4ce9f6
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
UPDATE_EMA_EVERY = 10
EXTS = ['jpg', 'jpeg', 'png']

RESULTS_FOLDER = Path('./results')
RESULTS_FOLDER.mkdir(exist_ok = True)

# helpers functions

def exists(x):
Expand Down Expand Up @@ -497,10 +500,10 @@ def save(self, milestone):
'model': self.model.state_dict(),
'ema': self.ema_model.state_dict()
}
torch.save(data, f'./model-{milestone}.pt')
torch.save(data, str(RESULTS_FOLDER / f'model-{milestone}.pt'))

def load(self, milestone):
data = torch.load(f'./model-{milestone}.pt')
data = torch.load(str(RESULTS_FOLDER / f'model-{milestone}.pt'))

self.step = data['step']
self.model.load_state_dict(data['model'])
Expand All @@ -527,7 +530,7 @@ def train(self):
batches = num_to_groups(36, self.batch_size)
all_images_list = list(map(lambda n: self.ema_model.sample(self.image_size, batch_size=n), batches))
all_images = torch.cat(all_images_list, dim=0)
utils.save_image(all_images, f'./sample-{milestone}.png', nrow=6)
utils.save_image(all_images, str(RESULTS_FOLDER / f'sample-{milestone}.png'), nrow=6)
self.save(milestone)

self.step += 1
Expand Down
Binary file removed sample.png
Binary file not shown.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.5.0',
version = '0.5.2',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit d4ce9f6

Please sign in to comment.