"""
MT3 baseline training.
To use random order, use `dataset.dataset_2_random`. Or else, use `dataset.dataset_2`.
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # silence TensorFlow C++ logs; must be set before TF is imported

import hydra
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader

from tasks.mt3_net import MT3Net  # noqa: F401  (instantiated indirectly via cfg.model._target_)

@hydra.main(config_path="config", config_name="config")
def main(cfg):
    # set seed to ensure reproducibility
    pl.seed_everything(cfg.seed)

    model = hydra.utils.instantiate(
        cfg.model,
        optim_cfg=cfg.optim,
        eval_cfg=cfg.eval
    )
    logger = TensorBoardLogger(save_dir='.',
                               name=f"{cfg.model_type}_{cfg.dataset_type}")

    # sanity check to make sure the correct model is used
    assert cfg.model_type == cfg.model._target_.split('.')[-1]

    lr_monitor = LearningRateMonitor(logging_interval='step')
    checkpoint_callback = ModelCheckpoint(**cfg.modelcheckpoint)
    tqdm_callback = TQDMProgressBar(refresh_rate=1)
    trainer = pl.Trainer(
        logger=logger,
        callbacks=[lr_monitor, checkpoint_callback, tqdm_callback],
        **cfg.trainer
    )
    # datasets and their collate_fn are instantiated from the Hydra config
    train_loader = DataLoader(
        hydra.utils.instantiate(cfg.dataset.train),
        **cfg.dataloader.train,
        collate_fn=hydra.utils.get_method(cfg.dataset.collate_fn)
    )
    val_loader = DataLoader(
        hydra.utils.instantiate(cfg.dataset.val),
        **cfg.dataloader.val,
        collate_fn=hydra.utils.get_method(cfg.dataset.collate_fn)
    )
    if cfg.path is not None and cfg.path != "":
        if cfg.path.endswith(".ckpt"):
            # resume from a full Lightning checkpoint (weights + optimizer state)
            print(f"Validating on {cfg.path}...")
            trainer.validate(model, val_loader, ckpt_path=cfg.path)

            print("Training start...")
            trainer.fit(model, train_loader, val_loader, ckpt_path=cfg.path)
        elif cfg.path.endswith(".pth"):
            # warm-start from raw weights only; optimizer state is not restored
            print(f"Loading weights from {cfg.path}...")
            model.model.load_state_dict(
                torch.load(cfg.path),
                strict=False
            )
            trainer.validate(model, val_loader)

            print("Training start...")
            trainer.fit(model, train_loader, val_loader)
        else:
            raise ValueError(f"Invalid extension for path: {cfg.path}")
    else:
        # train from scratch
        trainer.fit(model, train_loader, val_loader)
    # export the final weights in plain .pt format, stripping the
    # LightningModule's "model." prefix so they load into the bare network
    current_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    ckpt_path = os.path.join(
        current_dir,
        f"{cfg.model_type}_{cfg.dataset_type}",
        "version_0/checkpoints/last.ckpt"
    )
    model.eval()
    state_dict = model.state_dict()
    dic = {}
    for key in state_dict:
        if "model." in key:
            dic[key.replace("model.", "")] = state_dict[key]
        else:
            dic[key] = state_dict[key]
    torch.save(dic, ckpt_path.replace(".ckpt", ".pt"))
    print(f"Saved model in {ckpt_path.replace('.ckpt', '.pt')}.")


if __name__ == "__main__":
    main()
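# Consuming the exported weights elsewhere (a minimal sketch, assuming `net` is an
# instance of the bare network class that MT3Net wraps as `self.model`; that class
# is defined in tasks/mt3_net.py and is not shown here):
#
#   state = torch.load("<output_dir>/MT3Net_<dataset_type>/version_0/checkpoints/last.pt")
#   net.load_state_dict(state)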