This repository has been archived by the owner on Apr 26, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path test_type2_convGru.py
54 lines (38 loc) · 1.95 KB
/
test_type2_convGru.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import argparse
import os

import torch

import utils as utils
from sweaty_net_2_outputs import SweatyNet1
from conv_gru import ConvGruCellPreConv
def _build_parser():
    """Create the command-line parser for evaluating SweatyNet + ConvGRU."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--loadSweaty', required=True, help='path to pretrained Sweaty model')
    parser.add_argument('--loadGru', required=True, help='path to pretrained Gru cell')
    parser.add_argument('--testSet', required=True, help='dataroot of the test set')
    parser.add_argument('--trainSet', required=True, help='dataroot of the train set')
    # NOTE(review): --batch_size is accepted for CLI compatibility but is not
    # forwarded to the evaluation call below.
    parser.add_argument('--batch_size', type=int, default=4, help='batch size')
    parser.add_argument('--downsample', type=int, default=4, help='downsample')
    parser.add_argument('--p', type=float, default=0.001, help='percentage of abs threshold')
    parser.add_argument('--alpha', type=int, default=1000, help='multiplication factor for the teacher signals')
    parser.add_argument('--seq_len', type=int, required=True, help='length of the sequence')
    return parser


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero.

    Keeps the script from crashing with ZeroDivisionError when the test set
    yields no positives (tps + fns == 0) or no detections (fps + tps == 0).
    """
    if denominator == 0:
        return float('nan')
    return numerator / denominator


def main(argv=None):
    """Evaluate a pretrained SweatyNet + ConvGRU model and print RC / FDR.

    Loads both checkpoints onto the selected device and builds the train and
    test ``SoccerBallDataset``s. It derives the absolute threshold from the
    *train* set and runs ``utils.evaluate_type2_sweaty_gru_model`` on the
    *test* set. Finally it prints recall (RC) and false discovery rate (FDR).

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Tuple ``(rc, fdr)``; each is NaN when its denominator is zero.
    """
    opt = _build_parser().parse_args(argv)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    # map_location lets checkpoints saved on a GPU load on a CPU-only host.
    sweaty = SweatyNet1()
    sweaty.load_state_dict(torch.load(opt.loadSweaty, map_location=device))
    # Keep both networks on the same device the evaluator is told to use;
    # .to() is a no-op if a module is already there.
    sweaty.to(device)
    sweaty.eval()

    convGru = ConvGruCellPreConv(89, 1, device=device)
    convGru.load_state_dict(torch.load(opt.loadGru, map_location=device))
    convGru.to(device)
    convGru.eval()

    # os.path.join works whether or not the dataroot ends with a separator.
    test_data = utils.SoccerBallDataset(
        os.path.join(opt.testSet, "data.csv"), opt.testSet,
        downsample=opt.downsample, alpha=opt.alpha)
    train_data = utils.SoccerBallDataset(
        os.path.join(opt.trainSet, "data.csv"), opt.trainSet,
        downsample=opt.downsample, alpha=opt.alpha)

    # The detection threshold is computed from the training data, not the test data.
    threshold = utils.get_abs_threshold(train_data, opt.p)
    metrics = utils.evaluate_type2_sweaty_gru_model(
        sweaty, convGru, device, test_data, threshold,
        verbose=True, seq_len=opt.seq_len)

    rc = _safe_ratio(metrics['tps'], metrics['tps'] + metrics['fns'])
    fdr = _safe_ratio(metrics['fps'], metrics['fps'] + metrics['tps'])
    print("RC: {}, FDR: {}".format(rc, fdr))
    return rc, fdr


# Guard so importing this module (e.g. during pytest collection of test_*.py
# files) does not parse sys.argv or start loading models.
if __name__ == "__main__":
    main()