-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_model.py
157 lines (133 loc) · 4.22 KB
/
run_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import argparse
import random
import tensorflow as tf
from tensorflow.keras import optimizers
from src import dry_run, env
from src.model import (
autoencoder,
base,
clearsky,
conv2d,
conv3d,
conv3d_lm,
conv3d_tran,
embed_conv3d,
gru,
conv2d_mathe,
)
from src.session import Session
# Registry mapping each model's command-line name to the class that builds it.
# Built from an ordered sequence of (name, class) pairs: the order is kept
# because it is the order shown in the --model help text and in the
# create_model error message.
MODELS = dict(
    (
        (autoencoder.NAME_AUTOENCODER, autoencoder.Autoencoder),
        (conv2d.NAME, conv2d.CNN2D),
        (conv2d.NAME_CLEARSKY, conv2d.CNN2DClearsky),
        (conv3d.NAME, conv3d.CNN3D),
        (conv3d_tran.NAME, conv3d_tran.CNN3DTranClearsky),
        (embed_conv3d.NAME, embed_conv3d.Conv3D),
        (conv3d_lm.NAME, conv3d_lm.Conv3D),
        (clearsky.NAME, clearsky.Clearsky),
        (clearsky.NAME_MLP, clearsky.ClearskyMLP),
        (gru.NAME, gru.GRU),
        (conv2d_mathe.NAME_CLEARSKY_MATHE, conv2d_mathe.Conv2DMatheClearsky),
    )
)
def create_model(model_name: str) -> base.Model:
    """Create the model registered under ``model_name`` in ``MODELS``.

    Args:
        model_name: Key of the model in ``MODELS``.

    Returns:
        A new instance of the registered model class, built with no arguments.

    Raises:
        ValueError: If ``model_name`` is not a key of ``MODELS``. The message
            lists the available names.
    """
    # Only the lookup is guarded. A KeyError raised inside a model's
    # constructor must propagate as-is instead of being misreported as an
    # unknown model name.
    try:
        model_class = MODELS[model_name]
    except KeyError:
        raise ValueError(
            f"Bad model name, {model_name} does not exist.\n"
            + f"Available models are {list(MODELS.keys())}"
        ) from None  # The KeyError adds nothing beyond this message.
    return model_class()
def parse_args():
    """Parse the user's command-line arguments.

    The default values are the ones used to reproduce the original
    experiments.

    Returns:
        argparse.Namespace containing every option declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cache_file",
        # The pieces below are concatenated, so each one ends with a space.
        help="Tensorflow caching apply after model's preprocessing. "
        + "Note that this cache must be used only for a model with a specific configuration. "
        + "It must not be shared between models or the same model with different configuration.",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--run_local", help="Enable training with relative paths", action="store_true"
    )
    parser.add_argument(
        "--epochs", help="Number of epoch to train", default=25, type=int
    )
    parser.add_argument(
        "--test",
        help="Test a trained model on the test set. The value must be the model's checkpoint",
        default=None,
        type=str,
    )
    parser.add_argument(
        "--train", help="Train a model.", action="store_true",
    )
    parser.add_argument(
        "--dry_run",
        help="No training, no tensorflow, just the generator",
        action="store_true",
    )
    # run() forwards this flag to dry_run.run(). It was read but never
    # declared, so --dry_run crashed with AttributeError.
    parser.add_argument(
        "--enable_tf_caching",
        help="Caching flag forwarded to the dry run (only used with --dry_run)",
        action="store_true",
    )
    parser.add_argument(
        "--skip_non_cached",
        help="Skip images which are not already cached in the image reader",
        action="store_true",
    )
    parser.add_argument(
        "--seed", help="Seed for the experiment", default=1234, type=int
    )
    parser.add_argument(
        "--random_seed",
        help="Will override the default seed and use a random one",
        action="store_true",
    )
    parser.add_argument(
        "--no_checkpoint", help="Will not save any checkpoints", action="store_true",
    )
    parser.add_argument(
        "--checkpoint",
        help="The checkpoint to load before training.",
        default=None,
        type=str,
    )
    parser.add_argument("--lr", help="Learning rate", default=0.001, type=float)
    parser.add_argument(
        "--model",
        help=f"Name of the model to train, available models are:\n{list(MODELS.keys())}",
        type=str,
        required=True,
    )
    parser.add_argument("--batch_size", help="Batch size", default=128, type=int)
    return parser.parse_args()
def run(args):
    """Execute the actions requested on the command line.

    Depending on the flags, this function either:
    - runs the data generator alone (--dry_run), or
    - trains the selected model (--train) and/or evaluates a checkpoint on
      the test set (--test).

    The original docstring describes the loss as RMSE. That choice is made
    inside Session or the model, not here.

    Args:
        args: argparse.Namespace as returned by parse_args().
    """
    env.run_local = args.run_local

    # --random_seed opts out of deterministic seeding.
    if not args.random_seed:
        random.seed(args.seed)
        tf.random.set_seed(args.seed)

    if args.dry_run:
        # getattr keeps this working with parsers that do not declare
        # --enable_tf_caching. Reading the attribute directly raised
        # AttributeError.
        enable_tf_caching = getattr(args, "enable_tf_caching", False)
        dry_run.run(enable_tf_caching, args.skip_non_cached)
        # --dry_run is documented as "No training, no tensorflow, just the
        # generator", so stop before building the model and session.
        return

    model = create_model(args.model)
    session = Session(
        model=model, batch_size=args.batch_size, skip_non_cached=args.skip_non_cached,
    )

    if args.train:
        optimizer = optimizers.Adam(args.lr)
        session.train(
            optimizer=optimizer,
            cache_file=args.cache_file,
            enable_checkpoint=not args.no_checkpoint,
            epochs=args.epochs,
            checkpoint=args.checkpoint,
        )

    # Testing is independent of training: a run may train, test, or do both.
    if args.test is not None:
        session.test(args.test)
if __name__ == "__main__":
    # Script entry point: parse the command-line flags and hand them to run().
    run(parse_args())