-
Notifications
You must be signed in to change notification settings - Fork 5
/
generate_vectors.py
51 lines (36 loc) · 2.12 KB
/
generate_vectors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from datasets.datasets import load_data
import os
from utils.utils import load_parameters, set_parameter
import time
import pickle
if __name__ == "__main__":
    # Script entry point: iterate over every sample in the train / validation /
    # test dataloaders and pickle each feature vector into its own file, sorted
    # into "malicious" and "benign" sub-directories of the output directory.
    saved_vectors_directory = "sample_dataset_saved_feature_vectors"
    malicious_vector_filepath = os.path.join(saved_vectors_directory, "malicious")
    benign_vector_filepath = os.path.join(saved_vectors_directory, "benign")
    # makedirs creates the parent directory as needed and tolerates directories
    # that already exist, so no separate existence checks are required.
    for vector_directory in (malicious_vector_filepath, benign_vector_filepath):
        os.makedirs(vector_directory, exist_ok=True)

    # NOTE(review): the parameters are loaded BEFORE the set_parameter calls
    # below. If load_parameters returns a snapshot of parameters.ini, the
    # overrides will not be visible to load_data in this run -- confirm against
    # utils.utils and move this line below the set_parameter calls if needed.
    parameters = load_parameters("parameters.ini")

    # This flag must be on to generate the vector files; it changes what
    # PortableExecutableDataset returns so that the file path is included too.
    set_parameter("parameters.ini", "dataset", "generate_feature_vector_files", "True")
    # Batch sizes are forced to 1, presumably so that each batch holds exactly
    # one sample and filepath[0] below is that sample's path -- TODO confirm.
    set_parameter("parameters.ini", "hyperparam", "training_batch_size", "1")
    set_parameter("parameters.ini", "hyperparam", "test_batch_size", "1")

    train_dataloader_dict, valid_dataloader_dict, test_dataloader_dict, num_features = load_data(parameters)
    all_dataloader_dicts = (train_dataloader_dict, valid_dataloader_dict, test_dataloader_dict)

    # Report the total number of malicious samples, then of benign samples,
    # summed across the three splits.
    for sample_type in ('malicious', 'benign'):
        print(sum(len(loader_dict[sample_type].dataset) for loader_dict in all_dataloader_dicts))

    for data_dict in all_dataloader_dicts:
        for filetype in data_dict:
            dataloader = data_dict[filetype]
            # Anything that is not 'malicious' is stored with the benign vectors.
            if filetype == 'malicious':
                target_directory = malicious_vector_filepath
            else:
                target_directory = benign_vector_filepath
            for index, data in enumerate(dataloader):
                print(index, filetype)
                vector, _label, filepath = data
                # The output file is named after the source file. Samples that
                # share a base name will overwrite one another.
                filename = os.path.basename(filepath[0])
                # Use a context manager so every file handle is flushed and
                # closed; the dataset may contain a very large number of samples.
                with open(os.path.join(target_directory, filename), 'wb') as vector_file:
                    pickle.dump(vector, vector_file)