-
Notifications
You must be signed in to change notification settings - Fork 69
/
main_EUROC.py
120 lines (119 loc) · 3.86 KB
/
main_EUROC.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import torch
import src.learning as lr
import src.networks as sn
import src.losses as sl
import src.dataset as ds
import numpy as np
# Directory containing this script; results and cached data are placed under it.
base_dir = os.path.dirname(os.path.realpath(__file__))
# Root directory of the raw EuRoC dataset. The literal below is only a
# placeholder: set the EUROC_DATA_DIR environment variable (an empty value is
# ignored) or edit the fallback path before running.
data_dir = os.environ.get('EUROC_DATA_DIR') or '/path/to/EUROC/dataset'
# Network to test. Either point at a specific results directory, e.g.
#   address = os.path.join(base_dir, 'results/EUROC/2020_02_18_16_52_55/')
# or use "last" to test the last trained network.
address = "last"
################################################################################
# Network parameters
################################################################################
# Network to build, together with the parameters handed to it
# (presumably as constructor keyword arguments — confirm in src.learning).
net_class = sn.GyroNet
net_params = dict(
    in_dim=6,
    out_dim=3,
    c0=16,
    dropout=0.1,
    ks=[7, 7, 7, 7],
    ds=[4, 4, 4],
    momentum=0.1,
    # 1, 2 and 5 degrees, each converted to radians
    gyro_std=[deg * np.pi / 180 for deg in (1, 2, 5)],
)
################################################################################
# Dataset parameters
################################################################################
dataset_class = ds.EUROCDataset
# Sequences used for training. The validation split below is built from the
# very same list, so validation is measured on training sequences.
# NOTE(review): this mirrors the original configuration; confirm it is
# intentional before changing it.
_train_seqs = [
    'MH_01_easy',
    'MH_03_medium',
    'MH_05_difficult',
    'V1_02_medium',
    'V2_01_easy',
    'V2_03_difficult',
]
dataset_params = {
    # directory containing the raw dataset
    'data_dir': data_dir,
    # directory where preloaded data are recorded
    'predata_dir': os.path.join(base_dir, 'data/EUROC'),
    # train / validation / test sequences; each split gets its own list so
    # mutating one never affects another
    'train_seqs': list(_train_seqs),
    'val_seqs': list(_train_seqs),
    'test_seqs': [
        'MH_02_easy',
        'MH_04_difficult',
        'V2_02_medium',
        'V1_03_difficult',
        'V1_01_easy',
    ],
    # size of each trajectory during training; must be an integer multiple
    # of 'max_train_freq' (checked below)
    'N': 32 * 500,
    'min_train_freq': 16,
    'max_train_freq': 32,
}
# Fail fast instead of training with an inconsistent trajectory size.
if dataset_params['N'] % dataset_params['max_train_freq'] != 0:
    raise ValueError(
        "dataset_params['N'] must be an integer multiple of "
        "dataset_params['max_train_freq']")
################################################################################
# Training parameters
################################################################################
train_params = dict(
    # optimizer class and its settings
    optimizer_class=torch.optim.Adam,
    optimizer=dict(lr=0.01, weight_decay=1e-1, amsgrad=False),
    # loss class and its settings. min_N / max_N are int(log2(...)) of the
    # dataset's min/max training frequencies; int() truncates, so those
    # frequencies are presumably meant to be powers of two.
    loss_class=sl.GyroLoss,
    loss=dict(
        min_N=int(np.log2(dataset_params['min_train_freq'])),
        max_N=int(np.log2(dataset_params['max_train_freq'])),
        w=1e6,
        target='rotation matrix',
        huber=0.005,
        dt=0.005,  # presumably the IMU sampling period in seconds — confirm
    ),
    # With T_0=600 and T_mult=2 the first two cosine cycles last 600 and 1200
    # steps, i.e. exactly n_epochs=1800 if the scheduler is stepped once per
    # epoch (assumption — check the training loop).
    scheduler_class=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    scheduler=dict(T_0=600, T_mult=2, eta_min=1e-3),
    # presumably forwarded to torch.utils.data.DataLoader — confirm
    dataloader=dict(
        batch_size=10,
        pin_memory=False,
        num_workers=0,
        shuffle=False,
    ),
    freq_val=600,  # frequency of the validation step
    n_epochs=1800,  # total number of epochs
    res_dir=os.path.join(base_dir, "results/EUROC"),  # where results are recorded
    tb_dir=os.path.join(base_dir, "results/runs/EUROC"),  # TensorBoard log directory
)
################################################################################
# Train on training data set
################################################################################
# Set to True to train a new network before testing. The test step below
# loads the network selected by `address` (set near the top of this file;
# "last" selects the last trained network).
do_train = False
if do_train:
    learning_process = lr.GyroLearningBasedProcessing(
        train_params['res_dir'], train_params['tb_dir'], net_class,
        net_params, None, train_params['loss']['dt'])
    learning_process.train(dataset_class, dataset_params, train_params)
################################################################################
# Test on full data set
################################################################################
learning_process = lr.GyroLearningBasedProcessing(
    train_params['res_dir'], train_params['tb_dir'], net_class, net_params,
    address=address, dt=train_params['loss']['dt'])
# presumably evaluates on dataset_params['test_seqs'] — confirm in src.learning
learning_process.test(dataset_class, dataset_params, ['test'])