import glob
import os
from datetime import datetime

import torch

import meters
import utils
from dataloaders import get_data_loaders


class Trainer:
    def __init__(self, cfgs, model):
        self.device = cfgs.get("device", "cpu")
        self.num_epochs = cfgs.get("num_epochs", 30)
        self.batch_size = cfgs.get("batch_size", 64)
        self.checkpoint_dir = cfgs.get("checkpoint_dir", "results")
        self.save_checkpoint_freq = cfgs.get("save_checkpoint_freq", 1)
        self.keep_num_checkpoint = cfgs.get("keep_num_checkpoint", 2)  # -1 for keeping all checkpoints
        self.resume = cfgs.get("resume", True)
        self.run_finetune = cfgs.get("run_finetune", False)
        self.use_logger = cfgs.get("use_logger", True)
        self.log_freq = cfgs.get("log_freq", 100)
        self.archive_code = cfgs.get("archive_code", True)
        self.checkpoint_name = cfgs.get("checkpoint_name", None)
        self.test_result_dir = cfgs.get("test_result_dir", None)
        self.cfgs = cfgs

        self.metrics_trace = meters.MetricsTrace()
        self.make_metrics = lambda m=None: meters.StandardMetrics(m)
        self.model = model(cfgs)
        self.model.trainer = self
        self.train_loader, self.val_loader, self.test_loader = get_data_loaders(cfgs)

    def load_checkpoint(self, optim=True):
        """Search for the specified/latest checkpoint in checkpoint_dir and load the model and optimizer."""
        if self.checkpoint_name is not None:
            checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name)
        else:
            checkpoints = sorted(glob.glob(os.path.join(self.checkpoint_dir, "*.pth")))
            if len(checkpoints) == 0:
                return 0
            checkpoint_path = checkpoints[-1]
            self.checkpoint_name = os.path.basename(checkpoint_path)
        print(f"Loading checkpoint from {checkpoint_path}")
        cp = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_model_state(cp)
        if optim:
            self.model.load_optimizer_state(cp)
        self.metrics_trace = cp["metrics_trace"]
        epoch = cp["epoch"]
        return epoch

    def save_checkpoint(self, epoch, optim=True):
        """Save model, optimizer, and metrics state to a checkpoint in checkpoint_dir for the specified epoch."""
        utils.xmkdir(self.checkpoint_dir)
        checkpoint_path = os.path.join(self.checkpoint_dir, f"checkpoint{epoch:03}.pth")
        state_dict = self.model.get_model_state()
        if optim:
            optimizer_state = self.model.get_optimizer_state()
            state_dict = {**state_dict, **optimizer_state}
        state_dict["metrics_trace"] = self.metrics_trace
        state_dict["epoch"] = epoch
        print(f"Saving checkpoint to {checkpoint_path}")
        torch.save(state_dict, checkpoint_path)
        if self.keep_num_checkpoint > 0:
            utils.clean_checkpoint(self.checkpoint_dir, keep_num=self.keep_num_checkpoint)

    def save_clean_checkpoint(self, path):
        """Save model state only to specified path."""
        torch.save(self.model.get_model_state(), path)

    def test(self):
        """Perform testing."""
        self.model.to_device(self.device)
        self.current_epoch = self.load_checkpoint(optim=False)
        if self.test_result_dir is None:
            self.test_result_dir = os.path.join(
                self.checkpoint_dir, f"test_results_{self.checkpoint_name}".replace(".pth", "")
            )
        print(f"Saving testing results to {self.test_result_dir}")

        with torch.no_grad():
            m = self.run_epoch(self.test_loader, epoch=self.current_epoch, is_test=True)

        score_path = os.path.join(self.test_result_dir, "eval_scores.txt")
        self.model.save_scores(score_path)

    def train(self):
        """Perform training."""
        # archive code and configs
        if self.archive_code:
            utils.archive_code(os.path.join(self.checkpoint_dir, "archived_code.zip"), filetypes=[".py", ".yml"])
        utils.dump_yaml(os.path.join(self.checkpoint_dir, "configs.yml"), self.cfgs)

        # initialize
        start_epoch = 0
        self.metrics_trace.reset()
        self.train_iter_per_epoch = len(self.train_loader)
        self.model.to_device(self.device)
        self.model.init_optimizers()

        # resume from checkpoint
        if self.run_finetune:
            self.load_checkpoint(optim=False)
        elif self.resume:
            start_epoch = self.load_checkpoint(optim=True)

        # initialize tensorboardX logger
        if self.use_logger:
            from tensorboardX import SummaryWriter
            self.logger = SummaryWriter(
                os.path.join(self.checkpoint_dir, "logs", datetime.now().strftime("%Y%m%d-%H%M%S"))
            )
            # cache one batch for visualization
            self.viz_input = self.val_loader.__iter__().__next__()

        # run epochs
        print(f"{self.model.model_name}: optimizing to {self.num_epochs} epochs")
        for epoch in range(start_epoch, self.num_epochs):
            self.current_epoch = epoch
            metrics = self.run_epoch(self.train_loader, epoch)
            self.metrics_trace.append("train", metrics)

            with torch.no_grad():
                metrics = self.run_epoch(self.val_loader, epoch, is_validation=True)
            self.metrics_trace.append("val", metrics)

            if (epoch + 1) % self.save_checkpoint_freq == 0:
                self.save_checkpoint(epoch + 1, optim=True)
            self.metrics_trace.plot(pdf_path=os.path.join(self.checkpoint_dir, "metrics.pdf"))
            self.metrics_trace.save(os.path.join(self.checkpoint_dir, "metrics.json"))

        print(f"Training completed after {epoch+1} epochs.")

    def run_epoch(self, loader, epoch=0, is_validation=False, is_test=False):
        """Run one epoch."""
        is_train = not is_validation and not is_test
        metrics = self.make_metrics()

        if is_train:
            print(f"Starting training epoch {epoch}")
            self.model.set_train()
        else:
            print(f"Starting validation epoch {epoch}")
            self.model.set_eval()

        for iter, input in enumerate(loader):
            m = self.model.forward(input)
            if is_train:
                self.model.backward()
            elif is_test:
                self.model.save_results(self.test_result_dir)

            metrics.update(m, self.batch_size)
            print(f"{'T' if is_train else 'V'}{epoch:02}/{iter:05}/{metrics}")

            if self.use_logger and is_train:
                total_iter = iter + epoch * self.train_iter_per_epoch
                if total_iter % self.log_freq == 0:
                    with torch.no_grad():
                        self.model.forward(self.viz_input)
                        self.model.visualize(self.logger, total_iter=total_iter, max_bs=25)
        return metrics
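

# Hedged usage sketch: this shows how the Trainer above is typically driven,
# assuming a plain config dict and a model class that implements the interface
# the Trainer calls (to_device, init_optimizers, set_train/set_eval, forward,
# backward, get_model_state, load_model_state, ...). `MyModel`, the `model`
# module, and the config values here are illustrative placeholders, not names
# defined in this file.
if __name__ == "__main__":
    from model import MyModel  # hypothetical import; point this at the repo's actual model class

    cfgs = {
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "num_epochs": 30,
        "batch_size": 64,
        "checkpoint_dir": "results",
        "resume": True,
    }
    trainer = Trainer(cfgs, MyModel)  # Trainer instantiates the model itself via model(cfgs)
    trainer.train()
    # trainer.test()  # evaluates using the latest checkpoint in checkpoint_dir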