-
Notifications
You must be signed in to change notification settings - Fork 107
/
Copy pathtrain_net.py
57 lines (43 loc) · 1.9 KB
/
train_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from lib.config import cfg, args
from lib.networks import make_network
from lib.train import make_trainer, make_optimizer, make_lr_scheduler, make_recorder, set_lr_scheduler
from lib.datasets import make_data_loader
from lib.utils.net_utils import load_model, save_model, load_network
from lib.evaluators import make_evaluator
import torch.multiprocessing
def train(cfg, network):
    """Train ``network`` according to ``cfg`` and return it.

    Builds the trainer, optimizer, learning-rate scheduler, recorder and
    evaluator. It then asks ``load_model`` to restore state from
    ``cfg.model_dir`` (honouring ``cfg.resume``); ``load_model`` returns the
    epoch to start from. Training runs from that epoch up to (but excluding)
    ``cfg.train.epoch``.

    After each epoch the scheduler is stepped. A checkpoint is saved whenever
    the number of completed epochs is a multiple of ``cfg.save_ep``, and
    validation runs whenever it is a multiple of ``cfg.eval_ep``.

    Args:
        cfg: Project configuration object.
        network: Model to train.

    Returns:
        The same ``network`` object that was passed in.
    """
    # For datasets whose name does not begin with 'City', switch torch's
    # tensor-sharing strategy to 'file_system'.
    # NOTE(review): presumably a workaround for file-descriptor exhaustion in
    # DataLoader workers — confirm why 'City*' datasets are excluded.
    if not cfg.train.dataset.startswith('City'):
        torch.multiprocessing.set_sharing_strategy('file_system')

    trainer = make_trainer(cfg, network)
    optim = make_optimizer(cfg, network)
    lr_sched = make_lr_scheduler(cfg, optim)
    rec = make_recorder(cfg)
    evaluator = make_evaluator(cfg)

    # Restore network/optimizer/scheduler/recorder state; yields the first epoch to run.
    start_epoch = load_model(network, optim, lr_sched, rec, cfg.model_dir, resume=cfg.resume)

    train_data = make_data_loader(cfg, is_train=True, max_iter=cfg.ep_iter)
    val_data = make_data_loader(cfg, is_train=False)

    for epoch in range(start_epoch, cfg.train.epoch):
        rec.epoch = epoch
        trainer.train(epoch, train_data, optim, rec)
        lr_sched.step()

        completed = epoch + 1  # epochs finished so far (1-based count)
        if completed % cfg.save_ep == 0:
            save_model(network, optim, lr_sched, rec, epoch, cfg.model_dir)
        if completed % cfg.eval_ep == 0:
            trainer.val(epoch, val_data, evaluator, rec)

    return network
def test(cfg, network):
    """Evaluate a stored checkpoint of ``network`` on the validation split.

    Weights are loaded from ``cfg.model_dir`` through ``load_network``, using
    ``cfg.resume`` and ``cfg.test.epoch`` to choose the checkpoint. The epoch
    number it returns is forwarded to ``trainer.val`` together with the
    validation loader and evaluator. No recorder is passed to ``trainer.val``.

    Args:
        cfg: Project configuration object.
        network: Model to evaluate.
    """
    # Keep this order: the trainer is built before the weights are loaded.
    runner = make_trainer(cfg, network)
    loader = make_data_loader(cfg, is_train=False)
    metric = make_evaluator(cfg)
    loaded_epoch = load_network(network, cfg.model_dir, resume=cfg.resume, epoch=cfg.test.epoch)
    runner.val(loaded_epoch, loader, metric)
def main():
    """Script entry point: build the network and run the selected mode.

    The network is built with ``make_network(cfg)``. If the command-line
    flag ``args.test`` is set, it is evaluated with ``test``; otherwise it is
    trained with ``train``.
    """
    network = make_network(cfg)
    mode = test if args.test else train
    mode(cfg, network)


if __name__ == "__main__":
    main()