checkpoint.py

# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
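"""Checkpoint utilities for the CLIP training example: save per-epoch
checkpoints and restore them under constraints that prevent overwriting
checkpoints from a previous or incompatible run."""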
import glob
import os

import torch


def prepare_checkpoint_metrics(outputs, factor):
    # Average the accumulated loss outputs over `factor` to report one value.
    return {"Loss": outputs.div(factor).mean().item()}


def save_model(config, model, optimizer, epoch, metrics=None, scheduler=None):
    if config.checkpoint_dir:
        abs_pathd = os.path.abspath(config.checkpoint_dir)
        os.makedirs(abs_pathd, exist_ok=True)
        filename = f"CLIP_epoch_{epoch}.pt"
        save_path = os.path.join(abs_pathd, filename)
        model_state = model.state_dict()
        optimizer_state = optimizer.state_dict()
        # Guard against the default scheduler=None, which has no state_dict().
        scheduler_state = scheduler.state_dict() if scheduler is not None else None
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model_state,
                "optimizer_state_dict": optimizer_state,
                "metrics": metrics,
                "config": config,
                "scheduler_state_dict": scheduler_state,
            },
            save_path,
        )
        return save_path


def checkpoints_exist(path, config=None, inverse=None):
    if os.path.exists(path):
        if config is not None:
            # `prefix` is empty here, so the pattern matches every file in `path`.
            prefix = ""
            files = glob.glob(os.path.join(path, prefix + "*"))
            if inverse is not None:
                # Keep only the .pt files that do NOT match the prefix pattern,
                # i.e. checkpoints from another run.
                files_all = glob.glob(os.path.join(path, "*.pt"))
                files = list(set(files_all) - set(files))
        else:
            # All checkpoint files
            files = glob.glob(os.path.join(path, "*.pt"))
        if len(files) > 0:
            return True
    return False


def get_latest_filepath(path, config):
    prefix = ""
    files = glob.iglob(os.path.join(path, prefix + "*"))
    # glob already yields paths that include `path`, so the newest match can
    # be returned directly.
    latest_file = max(files, key=os.path.getctime)
    return latest_file


def load_from_file_passing_constraints(config):
    if config.checkpoint_file:
        abs_path_ckpt = os.path.abspath(config.checkpoint_file)
        # Check save constraints for preventing overwrite
        if config.checkpoint_dir and checkpoints_exist(os.path.abspath(config.checkpoint_dir)):
            raise RuntimeError(
                "Found previously saved checkpoint(s) at checkpoint-dir. "
                "Overwriting them with checkpoints building on checkpoint-file "
                "is not supported. Please specify a different checkpoint-dir to "
                "save checkpoints from this run."
            )
        # Return checkpoint if valid
        if os.path.isfile(abs_path_ckpt):
            try:
                checkpoint = torch.load(abs_path_ckpt)
                return checkpoint
            except Exception as e:
                print(f"Failed with exception {e}.")
        else:
            raise RuntimeError("Please specify a PyTorch checkpoint file.")
    return None


def load_from_dir_passing_constraints(config):
    # Latest checkpoint at checkpoint_dir
    if config.checkpoint_dir:
        abs_pathd = os.path.abspath(config.checkpoint_dir)
        print("abs_pathd: ", abs_pathd)
        # Check save constraints for resuming run without overwrite
        if checkpoints_exist(abs_pathd, config, inverse=True):
            raise RuntimeError(
                "Please specify a different checkpoint-dir. "
                "This one has checkpoints from another incompatible run."
            )
        if checkpoints_exist(abs_pathd, config):
            if config.restore_epochs_and_optimizer:
                abs_path_ckpt = get_latest_filepath(abs_pathd, config)
                try:
                    checkpoint = torch.load(abs_path_ckpt)
                    return checkpoint
                except Exception as e:
                    print(f"Failed with exception {e}.")
            else:
                # Overwrite prevention
                raise RuntimeError(
                    "Please restore full state to continue training. "
                    "Alternatively, specify a different checkpoint-dir "
                    "and a checkpoint-file from which to restore "
                    "(only model) state and retrain."
                )
    return None


def load_checkpoint_passing_constraints(config):
    # An explicit checkpoint-file takes precedence over checkpoint-dir.
    if config.checkpoint_file:
        checkpoint = load_from_file_passing_constraints(config)
    else:
        checkpoint = load_from_dir_passing_constraints(config)
    return checkpoint
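

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of the
# save/resume round trip these helpers implement. The SimpleNamespace config,
# the tiny nn.Linear model, the SGD optimizer, and the StepLR scheduler are
# illustrative placeholders, not the real training setup; only the attribute
# names (checkpoint_dir, checkpoint_file, restore_epochs_and_optimizer) come
# from the code above.
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        checkpoint_dir="checkpoints",
        checkpoint_file=None,
        restore_epochs_and_optimizer=True,
    )
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

    # Resume from the latest checkpoint in checkpoint_dir, if one exists.
    # (On PyTorch >= 2.6, torch.load may need weights_only=False to unpickle
    # the saved config object.)
    checkpoint = load_checkpoint_passing_constraints(config)
    start_epoch = 0
    if checkpoint is not None:
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1

    for epoch in range(start_epoch, start_epoch + 2):
        # ... training steps would run here; use fake per-step losses instead.
        outputs = torch.tensor([0.9, 0.8, 0.7])
        metrics = prepare_checkpoint_metrics(outputs, factor=1)
        path = save_model(config, model, optimizer, epoch, metrics=metrics, scheduler=scheduler)
        print(f"Saved checkpoint to {path}")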