diff --git a/train_dalle.py b/train_dalle.py
index 08ac2005..1206474e 100644
--- a/train_dalle.py
+++ b/train_dalle.py
@@ -178,7 +178,9 @@ def cp_path_to_dir(cp_path, tag):
     assert dalle_path.exists(), 'DALL-E model file does not exist'
     loaded_obj = torch.load(str(dalle_path), map_location='cpu')

-    dalle_params, vae_params, weights, opt_state, scheduler_state = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights'], loaded_obj['opt_state'], loaded_obj['scheduler_state']
+    dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']
+    opt_state = loaded_obj.get('opt_state')
+    scheduler_state = loaded_obj.get('scheduler_state')

     if vae_params is not None:
         vae = DiscreteVAE(**vae_params)
@@ -190,7 +192,7 @@ def cp_path_to_dir(cp_path, tag):
         **dalle_params
     )
     IMAGE_SIZE = vae.image_size
-    resume_epoch = loaded_obj['epoch']
+    resume_epoch = loaded_obj.get('epoch', 0)
 else:
     if exists(VAE_PATH):
         vae_path = Path(VAE_PATH)
@@ -296,7 +298,7 @@ def group_weight(model):
 # optimizer

 opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)
-if RESUME:
+if RESUME and opt_state:
     opt.load_state_dict(opt_state)

 if LR_DECAY:
@@ -309,7 +311,7 @@ def group_weight(model):
         min_lr=1e-6,
         verbose=True,
     )
-    if RESUME:
+    if RESUME and scheduler_state:
         scheduler.load_state_dict(scheduler_state)
 else:
     scheduler = None
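
The change makes resuming tolerant of checkpoints written before `opt_state`, `scheduler_state`, and `epoch` were part of the saved dict: `dict.get` returns `None` (or an explicit default) instead of raising `KeyError`, and the optimizer/scheduler state is only restored when it is actually present. Below is a minimal, self-contained sketch of that pattern; the `Linear` stand-in and the in-memory checkpoint dict are illustrative only, not part of this patch.

```python
import torch
from torch.nn import Linear
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Stand-ins for the real model/optimizer; the script itself builds a DALL-E.
model = Linear(4, 4)
opt = Adam(model.parameters(), lr=3e-4)
scheduler = ReduceLROnPlateau(opt, mode="min")

# Simulate an "old" checkpoint saved before opt_state/scheduler_state/epoch
# were added to the file format (normally this comes from torch.load).
loaded_obj = {
    "hparams": {},
    "vae_params": None,
    "weights": model.state_dict(),
}

# Required key: plain indexing raises KeyError if it is missing, which is
# the desired failure mode for a checkpoint without weights.
model.load_state_dict(loaded_obj["weights"])

# Optional keys: .get() yields None (or a chosen default) instead of raising.
opt_state = loaded_obj.get("opt_state")
scheduler_state = loaded_obj.get("scheduler_state")
resume_epoch = loaded_obj.get("epoch", 0)

# Only restore state that the checkpoint actually contains.
if opt_state:
    opt.load_state_dict(opt_state)
if scheduler_state:
    scheduler.load_state_dict(scheduler_state)

print(f"resuming from epoch {resume_epoch}")  # -> 0 for old checkpoints
```

Note that guarding with `if RESUME and opt_state:` (rather than `RESUME` alone) means a fresh optimizer/scheduler is used silently when the state is absent, while new-format checkpoints resume exactly as before.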