forked from sebastiankmiec/NinaTools
-
Notifications
You must be signed in to change notification settings - Fork 0
/
new_data_example.py
46 lines (34 loc) · 1.91 KB
/
new_data_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from ninaeval.config import config_parser, config_setup
from ninaeval.utils.data_extract import extract_myo_all_csv
# Handed to the dataset class selected in main() as its first constructor
# argument (presumably the directory holding processed datasets -- confirm).
DATA_PATH = "all_data/"
# Handed to the model class selected in main() as its first constructor
# argument (presumably the directory holding trained models -- confirm).
MODEL_PATH = "all_models/"
def main():
    """Build a dataset from custom recordings, then train, test and save a model.

    The feature extractor, model and dataset classes are looked up through
    ``config_setup`` using the parsed configuration. When
    ``dataset.load_dataset()`` returns a falsy value, the dataset is created
    from the ``new_data`` dictionary, which is meant to be populated by the
    ``extract_myo_all_csv`` calls shown (commented out) in the body.
    """
    # Settings come from a JSON file (--json) or from command line arguments.
    settings = config_parser.parse_config()

    # Resolve each component class first, then instantiate it.
    extractor_cls = config_setup.get_feat_extract(settings.features)
    extractor = extractor_cls()

    model_cls = config_setup.get_model(settings.model)
    model = model_cls(MODEL_PATH, extractor)

    dataset_cls = config_setup.get_dataset(settings.data)
    data = dataset_cls(DATA_PATH, extractor, False)

    # Only build the dataset when loading it did not succeed.
    was_loaded = data.load_dataset()
    if not was_loaded:
        # Filled by extract_myo_all_csv with the contents of myo_all_data.csv
        # files (obtained from the data collection GUI).
        new_data = {}

        print("Extracting dataset features for training, and testing...")

        # Replace these with your own data paths, for example:
        #
        # extract_myo_all_csv('/home/skmiec/Documents/ex5/a/myo_all_data.csv', new_data, "s11", "E1")
        # extract_myo_all_csv('/home/skmiec/Documents/ex5/b/myo_all_data.csv', new_data, "s11", "E2")
        # extract_myo_all_csv('/home/skmiec/Documents/ex5/c/myo_all_data.csv', new_data, "s11", "E3")
        #
        # extract_myo_all_csv('/home/skmiec/Documents/ex6/a/myo_all_data.csv', new_data, "s12", "E1")
        # extract_myo_all_csv('/home/skmiec/Documents/ex6/b/myo_all_data.csv', new_data, "s12", "E2")
        # extract_myo_all_csv('/home/skmiec/Documents/ex6/c/myo_all_data.csv', new_data, "s12", "E3")

        data.create_dataset(new_data, False)

    print("Training classifier on training dataset...")
    model.train_model(
        data.train_features,
        data.train_labels,
        data.test_features,
        data.test_labels,
    )

    print("Testing classifier on testing dataset...")
    inference_result = model.perform_inference(data.test_features, data.test_labels)
    print(inference_result)

    # NOTE(review): user-specific output location; adjust for your own machine.
    model.save_model("/home/skmiec/Documents/")
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()