diff --git a/train.py b/train.py index b9ca1017..8fad81da 100755 --- a/train.py +++ b/train.py @@ -380,16 +380,16 @@ def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, devic if args.ckpt is not None: print('load model:', args.ckpt) - + ckpt = torch.load(args.ckpt) try: ckpt_name = os.path.basename(args.ckpt) args.start_iter = int(os.path.splitext(ckpt_name)[0]) - + except ValueError: pass - + generator.load_state_dict(ckpt['g']) discriminator.load_state_dict(ckpt['d']) g_ema.load_state_dict(ckpt['g_ema']) @@ -397,6 +397,9 @@ def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema, devic g_optim.load_state_dict(ckpt['g_optim']) d_optim.load_state_dict(ckpt['d_optim']) + del ckpt + torch.cuda.empty_cache() + if args.distributed: generator = nn.parallel.DistributedDataParallel( generator,