train.py
import argparse
from pathlib import Path
import numpy as np
import glob

from datasets import DataInterface
from models import ModelInterface
from utils.utils import *

# pytorch_lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer

#---->Setting parameters
def make_parse():
    parser = argparse.ArgumentParser()
    parser.add_argument('--stage', default='train', type=str)
    parser.add_argument('--config', default='Camelyon/TransMIL.yaml', type=str)
    parser.add_argument('--gpus', default=[2])
    parser.add_argument('--fold', default=0)
    args = parser.parse_args()
    return args
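
# Example invocations (illustrative only; the config path mirrors the default
# above and the stage values match the check on cfg.General.server below --
# adjust paths, GPU ids and fold to your setup):
#   python train.py --stage train --config Camelyon/TransMIL.yaml
#   python train.py --stage test --config Camelyon/TransMIL.yaml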

#---->main
def main(cfg):

    #---->Initialize seed
    pl.seed_everything(cfg.General.seed)

    # load_loggers and load_callbacks are provided by utils.utils (wildcard import above)
    #---->load loggers
    cfg.load_loggers = load_loggers(cfg)

    #---->load callbacks
    cfg.callbacks = load_callbacks(cfg)

    #---->Define Data
    DataInterface_dict = {'train_batch_size': cfg.Data.train_dataloader.batch_size,
                          'train_num_workers': cfg.Data.train_dataloader.num_workers,
                          'test_batch_size': cfg.Data.test_dataloader.batch_size,
                          'test_num_workers': cfg.Data.test_dataloader.num_workers,
                          'dataset_name': cfg.Data.dataset_name,
                          'dataset_cfg': cfg.Data,}
    dm = DataInterface(**DataInterface_dict)

    #---->Define Model
    ModelInterface_dict = {'model': cfg.Model,
                           'loss': cfg.Loss,
                           'optimizer': cfg.Optimizer,
                           'data': cfg.Data,
                           'log': cfg.log_path,
                           }
    model = ModelInterface(**ModelInterface_dict)

    #---->Instantiate Trainer
    # Note: gpus and amp_level belong to the older (1.x) PyTorch Lightning Trainer API.
    trainer = Trainer(
        num_sanity_val_steps=0,
        logger=cfg.load_loggers,
        callbacks=cfg.callbacks,
        max_epochs=cfg.General.epochs,
        gpus=cfg.General.gpus,
        amp_level=cfg.General.amp_level,
        precision=cfg.General.precision,
        accumulate_grad_batches=cfg.General.grad_acc,
        deterministic=True,
        check_val_every_n_epoch=1,
    )

    #---->train or test
    if cfg.General.server == 'train':
        trainer.fit(model=model, datamodule=dm)
    else:
        # Evaluate every saved checkpoint in the log directory whose filename contains 'epoch'.
        model_paths = list(cfg.log_path.glob('*.ckpt'))
        model_paths = [str(model_path) for model_path in model_paths if 'epoch' in str(model_path)]
        for path in model_paths:
            print(path)
            new_model = model.load_from_checkpoint(checkpoint_path=path, cfg=cfg)
            trainer.test(model=new_model, datamodule=dm)


if __name__ == '__main__':

    args = make_parse()
    cfg = read_yaml(args.config)

    #---->update
    cfg.config = args.config
    cfg.General.gpus = args.gpus
    cfg.General.server = args.stage
    cfg.Data.fold = args.fold

    #---->main
    main(cfg)
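
# The config object comes from read_yaml (imported from utils.utils), which is
# expected to return an attribute-accessible mapping over the YAML file passed
# via --config. The keys sketched below are inferred from the attributes this
# script reads; they are illustrative, not the authoritative schema (see the
# repository's YAML files, e.g. Camelyon/TransMIL.yaml):
#
#   General:
#     seed: ...
#     epochs: ...
#     gpus: ...
#     amp_level: ...
#     precision: ...
#     grad_acc: ...
#   Data:
#     dataset_name: ...
#     train_dataloader: {batch_size: ..., num_workers: ...}
#     test_dataloader: {batch_size: ..., num_workers: ...}
#   Model: ...
#   Loss: ...
#   Optimizer: ...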