-
Notifications
You must be signed in to change notification settings - Fork 4
/
test.py
54 lines (43 loc) · 1.45 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""
Evaluate the test-set loss and accuracy of a trained model checkpoint.
Edited by Hichens.
Example:
    $ python test.py --model TSception
For more options, see the parameters defined in utils.options.
"""
import os
import torch
import time
import sys; sys.path.append("..")
import visdom
import numpy as np
from utils.utils import *
from utils.options import opt
# Prefer the second GPU as before, but fall back to the first GPU when only
# one is visible (a hard-coded 'cuda:1' raises on single-GPU machines), and
# to the CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'


def _checkpoint_path(options):
    """Return the path of the checkpoint file to evaluate.

    Args:
        options: parsed command-line options; reads ``checkpoint_dir``,
            ``model`` and ``best``.

    Returns:
        ``<checkpoint_dir>/<model>_best.pth`` when ``options.best`` is set,
        otherwise ``<checkpoint_dir>/<model>.pth``.
    """
    suffix = '_best.pth' if options.best else '.pth'
    return os.path.join(options.checkpoint_dir, options.model + suffix)


def _evaluate(net, loader, criterion, device):
    """Evaluate ``net`` on every batch of ``loader`` without tracking gradients.

    Args:
        net: model to evaluate; it is switched to eval mode here.
        loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss function called as ``criterion(outputs, targets)``
            and expected to return a per-batch mean.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``(mean_loss, mean_accuracy)`` over all samples. Each batch is
        weighted by its size, so a short final batch does not skew the
        result.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    net.eval()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0
    # Evaluation needs no autograd graph; skipping it saves memory and time.
    with torch.no_grad():
        for X, y in loader:
            # .to() is a no-op when the loader already placed tensors on device.
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            n = len(y)
            loss_sum += criterion(y_hat, y).item() * n
            # NOTE(review): assumes Accuracy returns the per-batch fraction
            # correct, as the previous per-batch averaging implied; confirm
            # in utils.utils.
            acc_sum += float(Accuracy(y_hat, y)) * n
            seen += n
    if seen == 0:
        raise ValueError("test loader yielded no samples")
    return loss_sum / seen, acc_sum / seen


if __name__ == "__main__":
    ## Network
    data = np.load(opt.test_data_path)
    label = np.load(opt.test_label_path)
    net = create_model(opt).to(device)
    # CrossEntropyLoss defaults to reduction='mean' (a per-batch mean), which
    # _evaluate re-weights by batch size.
    criterion = torch.nn.CrossEntropyLoss()

    checkpoint_path = _checkpoint_path(opt)
    print("load weight from: ", checkpoint_path)
    # map_location remaps tensors saved on another device (e.g. 'cuda:1')
    # onto the selected device, so GPU-saved checkpoints also load on CPU.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    load_checkpoint(net, checkpoint)

    ## Hyper Parameters
    test_loader = load_EEG_Datasets(data, label, opt.batch_size, is_val=False)

    ## Test
    print("testing on {} ...".format(device))
    mean_loss, mean_acc = _evaluate(net, test_loader, criterion, device)
    print("testing loss: %.4f, testing accuracy: %.4f" % (mean_loss, mean_acc))