Merge pull request #296 from mehdidc/resuming
Save/Resume optimizer state, scheduler state, and epoch
lucidrains authored Jun 16, 2021
2 parents 50fb971 + 7c631fb commit d6107cc
Showing 1 changed file with 20 additions and 10 deletions.
train_dalle.py (30 changes: 20 additions & 10 deletions)
@@ -186,7 +186,6 @@ def cp_path_to_dir(cp_path, tag):
     tokenizer = ChineseTokenizer()
 
 # reconstitute vae
-
 if RESUME:
     dalle_path = Path(DALLE_PATH)
     if using_deepspeed:
@@ -199,6 +198,8 @@ def cp_path_to_dir(cp_path, tag):
     loaded_obj = torch.load(str(dalle_path), map_location='cpu')
 
     dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj['vae_params'], loaded_obj['weights']
+    opt_state = loaded_obj.get('opt_state')
+    scheduler_state = loaded_obj.get('scheduler_state')
 
     if vae_params is not None:
         vae = DiscreteVAE(**vae_params)
@@ -212,6 +213,7 @@ def cp_path_to_dir(cp_path, tag):
         **dalle_params
     )
     IMAGE_SIZE = vae.image_size
+    resume_epoch = loaded_obj.get('epoch', 0)
 else:
     if exists(VAE_PATH):
         vae_path = Path(VAE_PATH)
@@ -253,6 +255,7 @@ def cp_path_to_dir(cp_path, tag):
         ff_dropout=FF_DROPOUT,
         attn_dropout=ATTN_DROPOUT,
     )
+    resume_epoch = 0
 
 # configure OpenAI VAE for float16s
 
@@ -320,6 +323,8 @@ def group_weight(model):
 # optimizer
 
 opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)
+if RESUME and opt_state:
+    opt.load_state_dict(opt_state)
 
 if LR_DECAY:
     scheduler = ReduceLROnPlateau(
@@ -331,6 +336,10 @@ def group_weight(model):
         min_lr=1e-6,
         verbose=True,
     )
+    if RESUME and scheduler_state:
+        scheduler.load_state_dict(scheduler_state)
+else:
+    scheduler = None
 
 if distr_backend.is_root_worker():
     # experiment tracker
@@ -398,10 +407,11 @@ def group_weight(model):
     distr_dalle.load_checkpoint(str(cp_dir))
 
 
-def save_model(path):
+def save_model(path, epoch=0):
     save_obj = {
         'hparams': dalle_params,
         'vae_params': vae_params,
+        'epoch': epoch,
     }
     if using_deepspeed:
         cp_dir = cp_path_to_dir(path, 'ds')
@@ -436,18 +446,18 @@ def save_model(path):
 
     save_obj = {
         **save_obj,
-        'weights': dalle.state_dict()
+        'weights': dalle.state_dict(),
+        'opt_state': opt.state_dict(),
     }
-
+    save_obj['scheduler_state'] = (scheduler.state_dict() if scheduler else None)
     torch.save(save_obj, path)
 
 # training
 
 # Saves a checkpoint before training begins to fail early when mis-configured.
 # See https://github.com/lucidrains/DALLE-pytorch/wiki/DeepSpeed-Checkpoints
-save_model(DALLE_OUTPUT_FILE_NAME)
-
-for epoch in range(EPOCHS):
+save_model(DALLE_OUTPUT_FILE_NAME, epoch=resume_epoch)
+for epoch in range(resume_epoch, EPOCHS):
     if data_sampler:
         data_sampler.set_epoch(epoch)
     for i, (text, images) in enumerate(distr_dl):
@@ -485,7 +495,7 @@ def save_model(path):
         }
 
         if i % SAVE_EVERY_N_STEPS == 0:
-            save_model(DALLE_OUTPUT_FILE_NAME)
+            save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
 
         if i % 100 == 0:
             if distr_backend.is_root_worker():
@@ -518,7 +528,7 @@ def save_model(path):
     if LR_DECAY:
         distr_scheduler.step(avg_loss)
 
-    save_model(DALLE_OUTPUT_FILE_NAME)
+    save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
 
     if distr_backend.is_root_worker():
         # save trained model to wandb as an artifact every epoch's end
@@ … @@
         model_artifact.add_file(DALLE_OUTPUT_FILE_NAME)
         run.log_artifact(model_artifact)
 
-save_model(DALLE_OUTPUT_FILE_NAME)
+save_model(DALLE_OUTPUT_FILE_NAME, epoch=epoch)
 if distr_backend.is_root_worker():
     wandb.save(DALLE_OUTPUT_FILE_NAME)
     model_artifact = wandb.Artifact('trained-dalle', type='model', metadata=dict(model_config))
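As a standalone illustration of the pattern this diff adds (checkpointing optimizer state, scheduler state, and the current epoch so training can resume where it left off, preserving Adam's moment estimates and the plateau scheduler's counters), here is a minimal sketch. It is not the repository's code: it assumes plain PyTorch with Adam and ReduceLROnPlateau, skips the DeepSpeed path entirely, and the save_checkpoint / load_checkpoint helpers and the 'dalle.pt' filename are hypothetical.

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

def save_checkpoint(path, model, opt, scheduler, epoch):
    # Bundle everything needed to resume: weights, Adam moment estimates,
    # plateau-scheduler counters, and the epoch to restart from.
    torch.save({
        'weights': model.state_dict(),
        'opt_state': opt.state_dict(),
        'scheduler_state': scheduler.state_dict() if scheduler else None,
        'epoch': epoch,
    }, path)

def load_checkpoint(path, model, opt, scheduler):
    # Restore the states saved above; .get() keeps older checkpoints
    # (written before these keys existed) loadable.
    ckpt = torch.load(path, map_location='cpu')
    model.load_state_dict(ckpt['weights'])
    if ckpt.get('opt_state'):
        opt.load_state_dict(ckpt['opt_state'])
    if scheduler and ckpt.get('scheduler_state'):
        scheduler.load_state_dict(ckpt['scheduler_state'])
    return ckpt.get('epoch', 0)  # epoch to resume from

model = torch.nn.Linear(8, 8)  # stand-in for the DALL-E model
opt = Adam(model.parameters(), lr=3e-4)
scheduler = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=10)

resume_epoch = 0
# resume_epoch = load_checkpoint('dalle.pt', model, opt, scheduler)  # when resuming

for epoch in range(resume_epoch, 3):
    # ... training steps would go here ...
    scheduler.step(1.0)  # stand-in for the epoch's average loss
    save_checkpoint('dalle.pt', model, opt, scheduler, epoch)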
