-
Notifications
You must be signed in to change notification settings - Fork 0
/
make_results_plots.py
120 lines (113 loc) · 7.12 KB
/
make_results_plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import matplotlib.pyplot as plt
import numpy as np
from pandas import read_csv, concat
import os
# ---------------------------------------------------------------- Setup ---
# Every result CSV read below (and every SVG written) lives in this
# directory; all bare file names are resolved relative to it.
os.chdir('result_csv_files')

plt.rcParams.update({'font.size': 32})
my_width = 0.20  # width of one bar inside a grouped bar chart
ebar_kw = {'elinewidth': 0.75, 'capsize': 1.5}  # passed as error_kw= to bar()

# Fixed model-label -> color mapping, so every model keeps the same color in
# every figure even when some model (e.g. ANN-3) is absent from a file.
labels = ['Exp', 'ANN-2', 'ANN-3', 'EN/RR/PLS', 'Our ANN', 'Train Mean']
colors = plt.rcParams['axes.prop_cycle'].by_key()['color'][:len(labels)]
color_dict = {label: color for label, color in zip(labels, colors)}

# Accumulators for the "Our ANN" predicted-vs-experimental scatter plot.
our_ANN_pred, exp_data = [], []
def _style_violin_parts(parts, color):
    """Paint every artist returned by one violinplot() call in *color*."""
    # violinplot() takes no color argument, so its artists are recolored
    # afterwards. parts['bodies'] is a list (a single body here, since one
    # dataset is drawn per call); the other entries are single artists.
    for artist in parts.values():
        if isinstance(artist, list):
            artist = artist[0]
        artist.set_color(color)


def _plot_relative_errors(csv_name, out_stem):
    """Draw one violin per model column of *csv_name*; save as <out_stem>.svg."""
    data = read_csv(csv_name)
    fig, ax = plt.subplots(figsize=(16, 9), dpi=500)
    # Some columns (e.g. ANN-3) contain NaN, so each model is drawn on its
    # own after dropping only that column's NaNs.
    for idx, model in enumerate(data.columns):
        parts = ax.violinplot(data.iloc[:, idx].dropna(), [idx], showmedians=True)
        # Color by model name for cross-figure consistency; fall back to the
        # positional color for names that are not in color_dict.
        _style_violin_parts(parts, color_dict.get(model, colors[idx + 1]))
    # Label the violin positions explicitly instead of relying on where the
    # automatic locator happens to put its ticks.
    ax.set_xticks(range(data.shape[1]))
    ax.set_xticklabels(data.columns.tolist())
    ax.set_ylabel('Percent Relative Error')
    ax.set_ylim(ax.get_ylim()[0], 101)
    fig.tight_layout()
    fig.savefig(f'{out_stem}.svg')
    plt.close(fig)  # release the (large, dpi=500) figure


def _drop_minor_glycans(df):
    """Keep only the glycan columns that have an 'Our ANN' prediction."""
    return df.iloc[:, df.loc['Our ANN'].notna().to_numpy()]


def _load_test_results(filename, csv_name):
    """Read one test-results CSV (row index = model/stat name).

    Returns (plot_stem, data). Fc_DAO_TestResults is merged column-wise with
    Fc_EPO_TestResults.csv and renamed to 'Fc-Domain_Both_Proteins_TestResults'
    so both Fc-domain proteins share one figure.
    """
    data = _drop_minor_glycans(read_csv(csv_name, index_col=0))
    if filename == 'Fc_DAO_TestResults':
        data2 = _drop_minor_glycans(read_csv('Fc_EPO_TestResults.csv', index_col=0))
        data = concat((data, data2), axis=1)
        # NOTE(review): assumes exactly one glycan column survives per
        # protein; any other count makes this assignment raise ValueError.
        data.columns = ['Fc_DAO-GnGnF', 'Fc_EPO-GnGnF']
        filename = 'Fc-Domain_Both_Proteins_TestResults'
    return filename, data


def _protein_suffix(filename):
    """Title suffix naming the protein a glycosite file belongs to ('' if unknown)."""
    site = '_'.join(filename.split('_')[:2])
    if site in {'Asn_24', 'Asn_38', 'Asn_83'}:
        return ' of EPO-Fc'
    if site in {'Asn_110', 'Asn_168', 'Asn_538', 'Asn_745'}:
        return ' of Fc-DAO'
    return ''


def _plot_test_results(filename, data):
    """Grouped bar chart of each model's values per glycan; save as <filename>.svg.

    The first rows of *data* are models; the last two rows are standard
    deviations: data.iloc[-2] is the error bar for 'Exp' and data.iloc[-1]
    the error bar for 'EN/RR/PLS'. The 'Exp' and 'Our ANN' rows are also
    appended to the module-level exp_data / our_ANN_pred lists.
    """
    if data.shape[1] == 0:
        # No glycan has an 'Our ANN' prediction: nothing to draw (and
        # x_points[0] below would raise IndexError).
        return
    n_models = data.shape[0] - 2  # exclude the two stdev rows
    # Glycan groups are n_models bars wide plus one bar-width of spacing.
    x_points = np.arange(data.shape[1]) * (n_models + 1) * my_width
    half_group = (n_models / 2 + 0.5) * my_width
    fig, ax = plt.subplots(figsize=(16, 9), dpi=500)
    for row_idx, row_name in enumerate(data.index[:n_models]):
        bar_shift = (row_idx - n_models / 2 + 0.5) * my_width
        values = data.iloc[row_idx, :]
        bar_kw = {'color': color_dict[row_name], 'label': row_name}
        if row_name == 'Exp':  # experimental data
            bar_kw.update(yerr=data.iloc[-2, :], error_kw=ebar_kw)
            exp_data.extend(values.values)
        elif row_name == 'EN/RR/PLS':  # SPA models
            bar_kw.update(yerr=data.iloc[-1, :], error_kw=ebar_kw)
        elif row_name == 'Our ANN':
            our_ANN_pred.extend(values.values)
        elif 'ANN' not in row_name and row_name != 'Train Mean':
            continue  # unrecognised row: not plotted
        ax.bar(x_points + bar_shift, values, my_width, **bar_kw)
    ax.set_title(' '.join(filename.split('_')[:-1]) + _protein_suffix(filename))
    ax.set_ylim(0, ax.get_ylim()[1])  # y-axis starts at 0
    ax.set_ylabel('Glycan Distribution')
    ax.set_xlim(x_points[0] - half_group, x_points[-1] + half_group)
    ax.set_xticks(x_points)
    ax.set_xticklabels(data.columns.to_list())
    fig.tight_layout()
    fig.savefig(f'{filename}.svg')
    plt.close(fig)  # release the (large, dpi=500) figure


# Only .csv files are processed; the .svg files this script writes into the
# same directory are ignored. Names are collected up front (and sorted for a
# reproducible processing order) so the directory handle is closed promptly.
with os.scandir() as dir_entries:
    csv_names = sorted(entry.name for entry in dir_entries
                       if os.path.splitext(entry.name)[1] == '.csv')

for csv_name in csv_names:
    filename = os.path.splitext(csv_name)[0]
    if filename == 'Relative_Errors':
        _plot_relative_errors(csv_name, filename)
    elif filename != 'Fc_EPO_TestResults':  # Fc_EPO is drawn together with Fc_DAO
        plot_stem, plot_data = _load_test_results(filename, csv_name)
        _plot_test_results(plot_stem, plot_data)
# Parity plot of the "Our ANN" test predictions against the experimental
# values accumulated in exp_data / our_ANN_pred while looping over the
# result CSVs. Dashed reference lines: y = x (black), y = (1 +/- 0.1) x
# (green) and y = (1 +/- 0.2) x (yellow).
exp_arr = np.asarray(exp_data, dtype=float)
pred_arr = np.asarray(our_ANN_pred, dtype=float)
fig, ax = plt.subplots(figsize=(9, 9), dpi=500)
ax.scatter(exp_arr, pred_arr)
ax.set_title('Test Predictions')
ax.set_xlabel('Real Values')
ax.set_ylabel('Predicted Values')
# nanmin/nanmax: a single missing value must not turn the axis limits into
# NaN (set_xlim/set_ylim reject NaN). Including 0 makes an all-positive
# dataset start at the origin.
axis_min_lim = min(np.nanmin(exp_arr), np.nanmin(pred_arr), 0.0)
axis_max_lim = max(np.nanmax(exp_arr), np.nanmax(pred_arr))
line_x = np.array([axis_min_lim, axis_max_lim + 0.01])
for factor, line_color in ((1.0, 'k'), (1.1, 'g'), (0.9, 'g'), (1.2, 'y'), (0.8, 'y')):
    ax.plot(line_x, factor * line_x, color=line_color, linestyle='--', linewidth=1)
# Identical x and y limits (on a square figure) keep y = x at 45 degrees.
ax.set_xlim(axis_min_lim - 0.01, axis_max_lim + 0.01)
ax.set_ylim(axis_min_lim - 0.01, axis_max_lim + 0.01)
# TODO: legend; maybe change lines to 5% and 10% error
fig.tight_layout()
fig.savefig('../result_csv_files/Our_ANN_Relative_Errors.svg')
# Second copy of the same figure with both upper limits cropped to 0.2.
ax.set_xlim(ax.get_xlim()[0], 0.2)
ax.set_ylim(ax.get_ylim()[0], 0.2)
fig.savefig('../result_csv_files/Our_ANN_Relative_Errors_Cropped.svg')
plt.close(fig)
# Violin plot of relative errors from the sibling result_csv_files_nested
# directory, one violin per CSV column.
data = read_csv('../result_csv_files_nested/Relative_Errors.csv')
fig, ax = plt.subplots(figsize=(16, 9), dpi=500)
ax.violinplot(data, showmedians=True)
# violinplot() places dataset i at x = i + 1; label exactly those positions
# rather than relying on where the automatic tick locator puts its ticks.
# NOTE(review): assumes the CSV columns appear in this order — confirm.
my_Xticklabels = ['EN/RR/PLS', 'EN/RR/PLS\n(Nested)', 'Our ANN', 'Our ANN\n(Nested)']
ax.set_xticks(range(1, len(my_Xticklabels) + 1))
ax.set_xticklabels(my_Xticklabels)
ax.set_ylabel('Percent Relative Error')
fig.tight_layout()
fig.savefig('../result_csv_files_nested/Relative_Errors_Nested.svg')
plt.close(fig)