make opt state/scheduler state/epoch non-mandatory to be compatible with current behavior
mehdidc committed Jun 12, 2021
1 parent d48ea80 commit dda155f
Showing 1 changed file with 6 additions and 4 deletions.
train_dalle.py: 10 changes (6 additions, 4 deletions)
@@ -178,7 +178,9 @@ def cp_path_to_dir(cp_path, tag):
     assert dalle_path.exists(), 'DALL-E model file does not exist'
     loaded_obj = torch.load(str(dalle_path), map_location='cpu')
 
-    dalle_params, vae_params, weights, opt_state, scheduler_state = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights'], loaded_obj['opt_state'], loaded_obj['scheduler_state']
+    dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']
+    opt_state = loaded_obj.get('opt_state')
+    scheduler_state = loaded_obj.get('scheduler_state')
 
     if vae_params is not None:
         vae = DiscreteVAE(**vae_params)
@@ -190,7 +192,7 @@ def cp_path_to_dir(cp_path, tag):
         **dalle_params
     )
     IMAGE_SIZE = vae.image_size
-    resume_epoch = loaded_obj['epoch']
+    resume_epoch = loaded_obj.get('epoch', 0)
 else:
     if exists(VAE_PATH):
         vae_path = Path(VAE_PATH)
@@ -296,7 +298,7 @@ def group_weight(model):
 # optimizer
 
 opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)
-if RESUME:
+if RESUME and opt_state:
     opt.load_state_dict(opt_state)
 
 if LR_DECAY:
@@ -309,7 +311,7 @@ def group_weight(model):
         min_lr=1e-6,
         verbose=True,
     )
-    if RESUME:
+    if RESUME and scheduler_state:
         scheduler.load_state_dict(scheduler_state)
 else:
     scheduler = None
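
For checkpoints written before opt_state, scheduler_state, and epoch were saved, the resume path now degrades gracefully instead of raising KeyError. A minimal, hypothetical sketch of the pattern (the key names follow the diff, but the model, optimizer settings, and checkpoint contents below are simplified stand-ins, not the script's real objects):

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Hypothetical stand-in for the DALL-E model rebuilt from 'hparams' / 'weights'.
model = torch.nn.Linear(8, 8)

# An "old-style" checkpoint dict that predates opt_state / scheduler_state / epoch.
loaded_obj = {
    'hparams': {'dim': 8},
    'vae_params': None,
    'weights': model.state_dict(),
}

# Mandatory keys still fail loudly if missing, as before.
dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']

# Optional keys fall back to None / 0 instead of raising KeyError.
opt_state = loaded_obj.get('opt_state')
scheduler_state = loaded_obj.get('scheduler_state')
resume_epoch = loaded_obj.get('epoch', 0)

opt = Adam(model.parameters(), lr=3e-4)
if opt_state:                          # restore only when the checkpoint carried optimizer state
    opt.load_state_dict(opt_state)

scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=10)
if scheduler_state:                    # likewise for the LR scheduler
    scheduler.load_state_dict(scheduler_state)

print(resume_epoch)                    # 0 for old checkpoints, the saved epoch otherwise

The guards use truthiness, so a missing key (None) simply skips the restore and training starts the optimizer and scheduler fresh.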
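On the saving side, newer runs keep full-resume support by writing the optional keys alongside the mandatory ones. A hedged sketch, assuming the checkpoint is a plain dict passed to torch.save (the save_checkpoint helper below is hypothetical, not the script's actual save routine):

import torch

def save_checkpoint(path, dalle, opt, scheduler, epoch, dalle_params, vae_params):
    # Optional entries are included when available; a loader using .get() tolerates their absence.
    save_obj = {
        'hparams': dalle_params,
        'vae_params': vae_params,
        'weights': dalle.state_dict(),
        'opt_state': opt.state_dict(),
        'scheduler_state': scheduler.state_dict() if scheduler is not None else None,
        'epoch': epoch,
    }
    torch.save(save_obj, path)

Because the loader falls back with .get(), both this fuller dict and an older one containing only 'hparams', 'vae_params', and 'weights' go through the same resume code path.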
