testModel_dataAug.lua
-- Evaluate the model on the validation indices using test-time data
-- augmentation. Assumes torch/nn are loaded and that `opt` (with fields
-- such as useCUDA and maxSequenceLength) is a global options table
-- supplied by the calling script.
require 'torch'
require 'nn'

function testModel(allData, model, valInds, epochError)
  print('testing corrected version 3')
  local timerTest = torch.Timer()
  -- use the same tensor type as the model (CUDA when opt.useCUDA is set)
  local dtype = 'torch.DoubleTensor'
  if opt.useCUDA then
    dtype = 'torch.CudaTensor'
  end
  local criterion = nn.ClassNLLCriterion():type(dtype)
  model:evaluate()
  -- push the validation data through the network
  local nValPrograms = #valInds
  local valError = 0
  local correct = 0
  local confmat = torch.zeros(2,2)
  local lens = torch.zeros(nValPrograms)
  -- Make sure the rare class is regarded as positive so that the f-score
  -- etc. are correctly calculated. When reading the data, benign is
  -- labelled 1 and malware 2.
  local nBenign = 0
  local nMalware = 0
  for k = 1,nValPrograms do
    if allData.label[valInds[k]] == 1 then
      nBenign = nBenign + 1
    else
      nMalware = nMalware + 1
    end
  end
  local positiveLabel = 1
  if nMalware < nBenign then
    positiveLabel = 2
  end
  print('Test Stats : nMalware ', nMalware, ' nBenign ', nBenign, ' positiveLabel ', positiveLabel)
  --local valBatch = torch.zeros(1,opt.programLen):type(dtype)
  local valLabel = torch.zeros(1):type(dtype)
  for k = 1,nValPrograms do
    valLabel[{1}] = allData.label[valInds[k]]
    --valBatch[{{1},{}}] = allData.program[valInds[k]]
    local currProgramPtr = allData.programStartPtrs[valInds[k]]
    local currProgramLen = allData.programLengths[valInds[k]]
    local netOutputProb = torch.zeros(1,2)
    local nDataAug = 10
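    -- Test-time augmentation: take nDataAug random windows of length
    -- opt.maxSequenceLength from the program (or the whole program if it is
    -- shorter), forward each window through the model, and sum the resulting
    -- class probabilities; the prediction is the argmax of the summed scores.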
    for j = 1,nDataAug do
      local valBatch
      if currProgramLen > opt.maxSequenceLength then
        -- random window into the program
        valBatch = torch.zeros(1,opt.maxSequenceLength):type(dtype)
        local rndPtr = torch.floor(torch.rand(1)[1] * (currProgramLen - opt.maxSequenceLength - 1))
        valBatch[{{1},{}}] = allData.program[{{currProgramPtr + rndPtr, currProgramPtr + rndPtr + opt.maxSequenceLength - 1}}]
      else
        -- program fits entirely within the maximum sequence length
        valBatch = torch.zeros(1,currProgramLen):type(dtype)
        valBatch[{{1},{}}] = allData.program[{{currProgramPtr, currProgramPtr + currProgramLen - 1}}]
      end
      -- if currProgramLen > opt.maxSequenceLength then
      --   currProgramLen = opt.maxSequenceLength
      -- end
      -- local valBatch = torch.zeros(1,currProgramLen):type(dtype)
      -- valBatch[{{1},{}}] = allData.program[{{currProgramPtr,currProgramPtr + currProgramLen - 1}}]
      local netOutput = model:forward(valBatch)
      valError = valError + criterion:forward(netOutput,valLabel)
      -- the network outputs log-probabilities (hence ClassNLLCriterion);
      -- Exp maps them back to probabilities before accumulating
      netOutputProb = netOutputProb + nn.Exp():forward(netOutput:double())
    end
    local v,i = torch.max(netOutputProb,2)
    local pred = i[{1,1}]
    local gt = allData.label[valInds[k]]
    if pred == gt then
      correct = correct + 1
    end
    confmat[pred][gt] = confmat[pred][gt] + 1
  end
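  -- confmat is indexed [prediction][ground truth], so the diagonal holds the
  -- correct decisions and the off-diagonal entries the false positives and
  -- false negatives with respect to whichever label was chosen as positive.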
  valError = valError / nValPrograms
  local tp = 0
  local fp = 0
  local fn = 0
  if positiveLabel == 1 then
    tp = confmat[1][1]
    fp = confmat[1][2]
    fn = confmat[2][1]
  else
    tp = confmat[2][2]
    fp = confmat[2][1]
    fn = confmat[1][2]
  end
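  -- precision = tp / (tp + fp), recall = tp / (tp + fn), and the f-score is
  -- their harmonic mean; precision/recall evaluate to nan (0/0) if the model
  -- makes no positive predictions or the split contains no positive examples.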
  local testResult = {
    -- tp = tp,
    -- fp = fp,
    -- fn = fn,
    prec = tp / (tp + fp),
    recall = tp / (tp + fn),
    fscore = (2 * tp) / ((2 * tp) + fp + fn),
    accuracy = correct / nValPrograms,
    testError = valError,
  }
  local time = timerTest:time().real
  model:training()
  -- clean up
  valBatch = nil
  valLabel = nil
  collectgarbage()
  return testResult, confmat, time
end
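
-- Usage sketch (hypothetical; the actual call site lives in the training
-- script of this repo, and the variable names there may differ):
--
--   local testResult, confmat, testTime = testModel(allData, model, valInds, trainError)
--   print(string.format('val error %.4f  acc %.3f  prec %.3f  recall %.3f  f-score %.3f (%.1fs)',
--                       testResult.testError, testResult.accuracy, testResult.prec,
--                       testResult.recall, testResult.fscore, testTime))
--   print(confmat)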