-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
executable file
·34 lines (30 loc) · 2.12 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import argparse
def _str_to_bool(value):
    """Convert command-line text to a bool, for use as an argparse ``type=``.

    ``type=bool`` cannot be used for this: argparse calls ``bool(text)``, and
    every non-empty string (including ``"False"`` and ``"0"``) is truthy, so
    such a flag could never be switched off from the command line.

    Args:
        value: Raw argument text. A value that is already a bool is returned
            unchanged.

    Returns:
        True for yes/true/t/y/1 and False for no/false/f/n/0
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean;
            argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("yes", "true", "t", "y", "1"):
        return True
    if text in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


def load_config(args=None):
    """Build the PialNN argument parser and parse the command line.

    Args:
        args: Optional list of argument strings to parse. When None (the
            default, and what existing callers get), argparse reads
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace containing every option declared below, each set
        to its default unless overridden.
    """
    # args
    parser = argparse.ArgumentParser(description="PialNN")
    # data
    parser.add_argument('--data_path', default="./data/train/", type=str, help="path of the dataset")
    parser.add_argument('--hemisphere', default="lh", type=str, help="left or right hemisphere (lh or rh)")
    # model
    parser.add_argument('--nc', default=128, type=int, help="num of channels")
    parser.add_argument('--K', default=5, type=int, help="kernal size")
    parser.add_argument('--n_scale', default=3, type=int, help="num of scales for image pyramid")
    parser.add_argument('--n_smooth', default=1, type=int, help="num of Laplacian smoothing layers")
    parser.add_argument('--lambd', default=1.0, type=float, help="Laplacian smoothing weights")
    # training
    parser.add_argument('--train_data_ratio', default=0.8, type=float, help="percentage of training data")
    parser.add_argument('--lr', default=1e-4, type=float, help="learning rate")
    parser.add_argument('--n_epoch', default=200, type=int, help="total training epochs")
    parser.add_argument('--ckpts_interval', default=10, type=int, help="save checkpoints after each n epoch")
    # Boolean flags take an explicit value (e.g. `--save_model False`);
    # _str_to_bool makes "False"/"0"/"no" actually parse as False.
    parser.add_argument('--report_training_loss', default=True, type=_str_to_bool, help="if report training loss")
    parser.add_argument('--save_model', default=True, type=_str_to_bool, help="if save training models")
    parser.add_argument('--save_mesh_train', default=False, type=_str_to_bool, help="if save mesh during training")
    # evaluation
    parser.add_argument('--save_mesh_eval', default=False, type=_str_to_bool, help="if save mesh during evaluation")
    parser.add_argument('--n_test_pts', default=150000, type=int, help="num of points sampled for evaluation")
    # GNN variant / model loading (how these are consumed is outside this file)
    parser.add_argument('--gnn_layers', default=2, type=int, help="number of layers in the gnn")
    parser.add_argument('--gnnVersion', default=1, type=int, help="0 for gcn, 1 for gat")
    parser.add_argument('--cortexGNN', default=False, type=_str_to_bool, help="Train with cortexGNN")
    parser.add_argument('--model_location', default="na", type=str, help="location of model")
    config = parser.parse_args(args)
    return config