diff --git a/train.py b/train.py index 5995df6..3151f33 100755 --- a/train.py +++ b/train.py @@ -50,7 +50,7 @@ def requires_grad(model, flag=True): for p in model.parameters(): p.requires_grad = flag - +@torch.no_grad() def accumulate(model1, model2, decay=0.999): par1 = dict(model1.named_parameters()) par2 = dict(model2.named_parameters()) @@ -293,9 +293,11 @@ def train(conf, loader, generator, discriminator, g_optim, d_optim, g_ema, devic ) if i % 100 == 0: + generator.zero_grad() + discriminator.zero_grad() with torch.no_grad(): g_ema.eval() - sample = g_ema(sample_z) + sample = g_ema(sample_z).cpu() utils.save_image( sample, f"sample/{str(i).zfill(6)}.png", @@ -303,6 +305,7 @@ def train(conf, loader, generator, discriminator, g_optim, d_optim, g_ema, devic normalize=True, value_range=(-1, 1), ) + sample = None # cleanup memory if i % 10000 == 0: torch.save(