-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_model_LUMIN_pretrain.py
109 lines (88 loc) · 3.47 KB
/
train_model_LUMIN_pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""Train LUMIN on top of a frozen, pretrained MF-U-Net using the PyTorch Lightning API."""
import sys
from pathlib import Path
import argparse
import random
import numpy as np
from utils.config import load_config
from utils.logging import setup_logging
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import (
EarlyStopping,
ModelCheckpoint,
LearningRateMonitor,
DeviceStatsMonitor,
)
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.profilers import PyTorchProfiler
from datamodules import SHMUDataModule
from models import LUMIN
from models import MFUNET
import wandb
import os
# Pretrained MF-U-Net artifacts whose weights are copied into LUMIN's
# ``mfunet_network`` and frozen. Override per run via ``main(...)`` keywords.
MFUNET_CONFIG_PATH = "/data/softec-lagrangian-nowcasting/configs/MFUNET_rmse/model.yaml"
MFUNET_CHECKPOINT_PATH = (
    "/data/softec-lagrangian-nowcasting/checkpoints/mfunet-rmse/epoch=6-step=2065.ckpt"
)


def main(
    configpath,
    checkpoint=None,
    *,
    mfunet_config=MFUNET_CONFIG_PATH,
    mfunet_checkpoint=MFUNET_CHECKPOINT_PATH,
):
    """Train LUMIN with a frozen, pretrained MF-U-Net sub-network.

    Reads ``datasets.yaml``, ``output.yaml`` and ``model.yaml`` from
    ``config/<configpath>``, copies the pretrained MF-U-Net weights into
    ``LUMIN.mfunet_network`` and disables their gradients, then fits the model
    with a Lightning ``Trainer`` and writes the final weights to disk.

    Args:
        configpath: Configuration folder name, resolved relative to ``config/``.
        checkpoint: Optional Lightning checkpoint to resume training from;
            forwarded to ``Trainer.fit`` as ``ckpt_path``.
        mfunet_config: Path to the pretrained MF-U-Net ``model.yaml``.
        mfunet_checkpoint: Path to the pretrained MF-U-Net Lightning checkpoint.

    Side effects:
        Writes checkpoints and W&B logs under ``checkpoints/<savefile>/``,
        uploads the config folder's files to the W&B run, and saves
        ``state_dict_<savefile>.ckpt`` and ``<savefile>.ckpt`` in the current
        working directory.
    """
    confpath = Path("config") / configpath
    dsconf = load_config(confpath / "datasets.yaml")
    outputconf = load_config(confpath / "output.yaml")
    modelconf = load_config(confpath / "model.yaml")
    modelconf_mf = load_config(mfunet_config)
    train_params = modelconf.train_params
    savefile = train_params.savefile

    # Fixed seeds for the torch, ``random`` and NumPy RNGs.
    torch.manual_seed(1)
    random.seed(1)
    np.random.seed(1)
    torch.set_float32_matmul_precision("high")

    setup_logging(outputconf.logging)

    datamodel = SHMUDataModule(dsconf, train_params)

    # ``load_from_checkpoint`` is a classmethod that constructs and returns a new
    # model, so call it on the class. Calling it on an instance built a throwaway
    # MFUNET first, and Lightning >= 2.0 rejects that call outright.
    model_mf = MFUNET.load_from_checkpoint(mfunet_checkpoint, config=modelconf_mf)
    model = LUMIN(modelconf)
    # Transfer the pretrained weights, then stop gradients from flowing into them.
    model.mfunet_network.load_state_dict(model_mf.network.state_dict())
    for param in model.mfunet_network.parameters():
        param.requires_grad = False
    # NOTE(review): only requires_grad is disabled; train/eval mode of any
    # normalization or dropout layers in mfunet_network is left to LUMIN — confirm.
    del model_mf  # release the donor model

    # Callbacks
    ckpt_dir = f"checkpoints/{savefile}"
    model_ckpt = ModelCheckpoint(
        dirpath=ckpt_dir,
        save_top_k=1,
        monitor="val_loss",
        save_on_train_epoch_end=False,
    )
    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    early_stopping = EarlyStopping(**train_params.early_stopping)
    device_monitor = DeviceStatsMonitor()
    logger = WandbLogger(
        save_dir=f"{ckpt_dir}/wandb",
        project=savefile,
        log_model=True,
    )
    profiler = PyTorchProfiler(profile_memory=False)

    # WandbLogger starts its run lazily, so ``wandb.run`` is still None at this
    # point; ``logger.experiment`` creates the run on first access. Upload the
    # folder the configs were actually read from (``config/<configpath>``).
    logger.experiment.save(str(confpath / "*"), policy="now")

    trainer = pl.Trainer(
        profiler=profiler,
        logger=logger,
        val_check_interval=train_params.val_check_interval,
        max_epochs=train_params.max_epochs,
        max_time=train_params.max_time,
        devices=train_params.gpus,
        limit_val_batches=train_params.val_batches,
        limit_train_batches=train_params.train_batches,
        callbacks=[
            early_stopping,
            model_ckpt,
            lr_monitor,
            device_monitor,
        ],
        log_every_n_steps=1,
    )
    trainer.fit(model=model, datamodule=datamodel, ckpt_path=checkpoint)

    # With several devices every process reaches this point: write the raw
    # state dict once. ``save_checkpoint`` coordinates ranks itself and is
    # meant to be called on all of them.
    if trainer.is_global_zero:
        torch.save(model.state_dict(), f"state_dict_{savefile}.ckpt")
    trainer.save_checkpoint(f"{savefile}.ckpt")
if __name__ == "__main__":
    # Command-line entry point: a required config folder name plus an optional
    # checkpoint to resume training from.
    cli = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    cli.add_argument("config", type=str, help="Configuration folder")
    cli.add_argument(
        "-c",
        "--continue_training",
        type=str,
        default=None,
        help="Path to checkpoint for model that is continued.",
    )
    parsed = cli.parse_args()
    main(parsed.config, checkpoint=parsed.continue_training)