-
Notifications
You must be signed in to change notification settings - Fork 0
/
complete_training.py
58 lines (46 loc) · 1.87 KB
/
complete_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
from models import LstmModel
from pytorch_lightning import Trainer
from argparse import ArgumentParser
from data import AudioDataModule
from utils import latest_checkpoint_from_folder
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
def compress_and_train(args):
    """Resume and complete training of a pruned LSTM from its latest checkpoint.

    The run directory is derived from the CLI arguments
    (``<device>-LSTM<hidden>-prune<pct>[-ftf]/<prune_iter>``); the newest
    checkpoint in that folder is restored, then the model is fitted and
    tested, with outputs written under a ``completed`` subdirectory.

    Args:
        args: parsed ``argparse`` namespace carrying device/model/pruning
            options plus the ``AudioDataModule`` arguments.
    """
    # Reconstruct the directory of the pruning run this training continues.
    run_name = f"{args.device_name}-LSTM{args.hidden_size}-prune{args.prune_pct}"
    if args.fully_train_first:
        run_name += "-ftf"
    run_dir = os.path.join(run_name, str(args.prune_iter))

    checkpoint = latest_checkpoint_from_folder(run_dir)
    datamodule = AudioDataModule.from_argparse_args(args)

    trainer = Trainer(
        default_root_dir=os.path.join(run_dir, "completed"),
        gpus=args.num_gpus,
        check_val_every_n_epoch=5,
        enable_progress_bar=False,
        num_sanity_val_steps=0,
        # Keep only the best val_loss checkpoint (plus "last") for testing.
        callbacks=[ModelCheckpoint(monitor="val_loss", save_top_k=1, save_last=True)],
        max_epochs=750,
    )

    model = LstmModel.load_from_checkpoint(checkpoint)
    trainer.fit(model, datamodule)
    # Evaluate the best (lowest val_loss) checkpoint produced by this fit.
    trainer.test(model, datamodule, ckpt_path="best")
if __name__ == "__main__":
    # CLI entry point: combine data-module options with pruning/training flags.
    cli = ArgumentParser()
    AudioDataModule.add_argparse_args(cli)
    for flag, default in (
        ("--prune_pct", 30),
        ("--prune_iter", 0),
        ("--hidden_size", 64),
        ("--early_stopping_patience", 25),
        ("--num_gpus", 1),
    ):
        cli.add_argument(flag, type=int, default=default)
    cli.add_argument("--fully_train_first", action="store_true")
    compress_and_train(cli.parse_args())