-
Notifications
You must be signed in to change notification settings - Fork 0
/
main_models.py
45 lines (34 loc) · 1.01 KB
/
main_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from anemia_data import anemia_data
from model_comparison import *
from models import *
# Driver script: load the anemia dataset, draw exploratory plots, run three
# classifiers that share one test-split size, then draw the model-comparison
# graphs. There is no __main__ guard, so all of this runs at import time.
data = anemia_data()
default_test_size = 0.5  # shared by every model and comparison graph below


def _evaluate_and_report(model, heading, trailer=None):
    """Run *model* and print its evaluation.

    Calls ``model.predict()``, prints *heading*, then calls
    ``model.print_metrics()`` followed by ``model.get_confusion_matrix()``.
    When *trailer* is given it is printed last (the callers pass ``"\\n\\n"``
    as a visual separator between sections).
    """
    model.predict()
    print(heading)
    model.print_metrics()
    model.get_confusion_matrix()
    if trailer is not None:
        print(trailer)


# Exploratory plots of the dataset, drawn before any model runs. Each method
# is looked up on `data` and invoked one after another, in this order.
for _plot_method in (
    "get_heatmap",
    "plot_results",
    "plot_hemoglobin",
    "plot_MCHC",
    "plot_MCV",
):
    getattr(data, _plot_method)()

# Logistic regression to predict anemia.
# NOTE(review): the meaning of 100 is defined in models.py (presumably a
# hyperparameter) — confirm there.
model_1 = anemia_logistic_regression(data, 100, default_test_size)
_evaluate_and_report(model_1, "\nLogistic Regression metrics", "\n\n")

# Decision tree to predict anemia.
# NOTE(review): the meaning of 50 is defined in models.py — confirm there.
model_2 = anemia_decision_tree(data, 50, default_test_size)
_evaluate_and_report(model_2, "Decision tree metrics", "\n\n")

# k-nearest neighbors to predict anemia.
# NOTE(review): unlike the two constructors above, the test size is the second
# positional argument here; 21 is presumably the neighbour count — confirm
# against anemia_knn's signature.
model_3 = anemia_knn(data, default_test_size, 21)
_evaluate_and_report(model_3, "K-Nearest Neighbors metrics")

# Model comparisons, one graph per model family, in the same order as above.
for _graph in (metrics_graph_lr, metrics_graph_dt, metrics_graph_knn):
    _graph(data, default_test_size)