-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.py
146 lines (115 loc) · 3.94 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
main training
edit by Hichens
Example:
>>>python main.py --model TSception
for more option, please see the utils.options for more paramters
"""
import os
import torch
import time
import sys; sys.path.append("..")
import visdom
import numpy as np
from utils.utils import *
from utils.options import opt
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
## Network
data = np.load(opt.train_data_path)
label = np.load(opt.train_label_path)
net = create_model(opt).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=opt.learning_rate, \
weight_decay=opt.weight_decay)
checkpoint_path = os.path.join(opt.checkpoint_dir, opt.model+'.pth')
best_checkpoint_path = os.path.join(opt.checkpoint_dir, opt.model+'_best.pth')
if os.path.exists(checkpoint_path) and opt.pretrained == True:
checkpoint = torch.load(checkpoint_path)
load_checkpoint(net, checkpoint)
## Hyper Parameters
num_epochs = opt.num_epochs
batch_size = opt.batch_size
train_loader, val_loader = load_EEG_Datasets(data, label, batch_size, is_val=True)
## Visualize
vis = visdom.Visdom(env='main', port=opt.display_port)
## Train
print("training on {} ...".format(device))
train_loss, train_acc = [], []
val_loss, val_acc = [], []
best_train_acc, best_val_acc, patient = 0, 0, 0
for epoch in range(num_epochs):
    epoch_start = time.time()

    ## Training pass
    net.train()
    loss, acc = [], []
    for i, (X, y) in enumerate(train_loader):
        # Move the batch onto the model's device; without this the forward
        # pass fails with a device mismatch whenever the net is on the GPU.
        X, y = X.to(device), y.to(device)
        y_hat = net(X)  # batch_size x n_classes (2 per the original comment)
        loss_fn = criterion(y_hat, y)
        if opt.normalized:
            # Optional weight-regularization term added to the loss.
            loss_r = regulization(net)
            loss_fn += loss_r
        loss.append(loss_fn.item())
        optimizer.zero_grad()
        loss_fn.backward()
        optimizer.step()
        acc.append(Accuracy(y_hat, y))
    train_loss.append(sum(loss) / len(loss))
    train_acc.append(sum(acc) / len(acc))
    print("epoch: %d, training loss: %.4f, training accuracy: %.4f, time: %d"%(
        epoch+1, train_loss[epoch], train_acc[epoch], time.time() - epoch_start
    ))

    ## Early stop: count epochs whose accuracy falls well below the best seen.
    if best_train_acc < train_acc[epoch]:
        best_train_acc = train_acc[epoch]
    elif train_acc[epoch] < best_train_acc * 0.9:
        patient += 1
        if patient > opt.patient:
            print("=> early stop!")
            break

    ## Validation — inference only, so disable autograd to save memory/time.
    net.eval()
    loss, acc = [], []
    with torch.no_grad():
        for i, (X, y) in enumerate(val_loader):
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            loss_fn = criterion(y_hat, y)
            loss.append(loss_fn.item())
            acc.append(Accuracy(y_hat, y))
    val_loss.append(sum(loss) / len(loss))
    val_acc.append(sum(acc) / len(acc))
    print("validate loss: %.4f, validate accuracy: %.4f"%(val_loss[epoch], val_acc[epoch]))

    ## Save the best state seen so far on the validation set.
    if best_val_acc < val_acc[epoch]:
        best_val_acc = val_acc[epoch]
        state = {
            'state_dict': net.state_dict(),
        }
        save_checkpoint(state, best_checkpoint_path)

    ## Plot the four curves in visdom (one window, updated every epoch).
    vis.line(X=list(range(epoch+1)),
             Y=np.column_stack((train_loss, train_acc, val_loss, val_acc)),
             win='train',
             opts={
                 'title': opt.model + '--train',
                 'dash': np.array(['solid', 'solid', 'solid', 'solid']),
                 'legend': ['train_loss', 'train_acc', 'val_loss', 'val_acc'],
                 'showlegend': True,
             })
print("training done !")
## Save checkpoint
net.eval()
state = {
'state_dict': net.state_dict(),
}
save_checkpoint(state, checkpoint_path)
## Test
data = np.load(opt.test_data_path)
label = np.load(opt.test_label_path)
test_loader = load_EEG_Datasets(data, label, batch_size=opt.batch_size, is_val=False)
print("=="*20)
print("testing on {} ...".format(device))
net.eval()
loss, acc = [], []
# Inference only: disable autograd, and move each batch to the model's device
# (without .to(device) this fails whenever the net is on the GPU).
with torch.no_grad():
    for i, (X, y) in enumerate(test_loader):
        X, y = X.to(device), y.to(device)
        y_hat = net(X)
        loss_fn = criterion(y_hat, y)
        loss.append(loss_fn.item())
        acc.append(Accuracy(y_hat, y))
print("testing loss: %.4f, testing accuracy: %.4f"%(sum(loss) / len(loss), sum(acc) / len(acc)))