-
Notifications
You must be signed in to change notification settings - Fork 0
/
interpretable_pdf.py
350 lines (279 loc) · 14 KB
/
interpretable_pdf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
from fpdf import FPDF
import numpy as np
from nltk.stem import WordNetLemmatizer
import re
from explainability import (get_ti_feature_contributions_for_instance_i,
run_tree_interpreter)
class InterpretablePDF:
"""
Class to produce formatted vignettes that indicate which text elements are contributing to
any given classification.
Currently this class only works with the CAP Prostate Cancer data, because the formatting
corresponds to the unique structure of this text data. However, it could easily be adapted to
work with other textual data (such as LeDeR).
"""
def __init__(self,
classifier,
x_data,
y_data,
feature_columns,
base_font_size=12, line_height=8,
header_col_width=100, legend_offset=47.5,
legend_offset_2=63,
top_n_features=None,
contributions=None):
self.font_size = base_font_size
self.line_height = line_height
self.header_col_width = header_col_width
self.legend_offset = legend_offset
self.legend_offset_2 = legend_offset_2
self.top_n = top_n_features
self.pdf = None
self.original_data = None
self.feature_columns = feature_columns
self.X = x_data
self.y = y_data
self.clf = classifier
_, _, self.contributions = (run_tree_interpreter(self.clf,
self.X)
if contributions is None else (
None, None, contributions))
self.stemmer = WordNetLemmatizer()
self.section_headers = ['Clinical features at diagnosis',
'Treatments received',
'Prostate cancer progression',
'Progression of co-morbidities',
'End of life']
self.vignette_column_names = [
'Gleason Score at diagnosis (with dates)',
'Clinical stage (TNM)',
'Pathological stage (TNM)',
'Co-morbidities with dates of diagnosis',
'Other primary cancers with dates of diagnosis',
'PSA level at diagnosis with dates',
'Radiological evidence of local spread at diagnosis',
'Radiological evidence of metastases at diagnosis',
'Initial treatments (dates)',
'Hormone therapy (start date)',
'Maximum androgen blockade (start date)',
'Orchidectomy (date)',
'Chemotherapy (start date)',
'Treatment for complications of treating prostate cancer with dates (if available)',
'Serial PSA levels (dates)',
'Serum testosterone',
'Radiological evidence of metastases',
'Other indications or complications of disease progression',
'Date of recurrence following radical surgery or radiotherapy',
'Palliative care referrals and treatments',
'Treatment/ admission for co-morbidity with dates (if available)',
'Symptoms in last 3-6 months (i.e. bone pain, weight loss, cachexia,\
loss of appetite, obstructive uraemia)',
'Last consultation: speciality & date',
'Was a DS1500 report issued?',
'Post mortem findings']
def create_pdf(self, case_id, original_data, filename):
self.original_data = original_data
self.pdf = FPDF()
self.pdf.add_page()
self.pdf.set_font('Arial', 'B', self.font_size)
self.pdf.cell(w=0,
h=self.line_height,
txt='Interpretable Vignette Classification for Cause of Death Review',
border=0,
ln=0,
align='C', fill=False, link='')
self.pdf.set_font('')
self.pdf.ln()
self.pdf.ln()
y = self.pdf.get_y()
self.pdf.multi_cell(w=self.header_col_width,
h=self.line_height,
txt='Study ID number : %s\nDate of death : %s\nDate of diagnosis : %s' % (
original_data['cp1random_id_5_char'],
original_data['cnr19datedeath'].date(),
original_data['cnr_date_pca_diag'].date()),
border=0,
align='L',
fill=False)
self.pdf.y = y
self.pdf.x = self.header_col_width
self.pdf.multi_cell(w=self.header_col_width,
h=self.line_height,
txt='Predicted death code: %d (%.2f)\nActual death code: %d\nCOD route: %d' % (
self.clf.best_estimator_.predict(self.X)[case_id],
self.clf.best_estimator_.predict_proba(self.X)[case_id][1],
original_data.pca_death_code,
original_data.cp1do_cod_route),
border=0,
align='R', fill=False)
self.pdf.set_line_width(0.5)
self.pdf.line(10, self.pdf.get_y(), 210 - 10, self.pdf.get_y())
self.pdf.set_line_width(0.2)
self.write_legend(case_id)
for ci, col in enumerate(self.feature_columns):
self.pdf.set_text_color(0, 0, 0)
self.pdf.set_font_size(self.font_size)
if ci in [0, 8, 14, 20, 21]:
self.print_section_header(0)
self.pdf.line(10, self.pdf.get_y(), 210 - 10, self.pdf.get_y())
self.pdf.write(self.line_height, self.vignette_column_names[ci] + ': ')
if 'palliative' not in col:
text = str(original_data[col])
self.print_paragraph(text, case_id)
self.pdf.output(filename)
def get_font_size_and_color(self,
contribution,
min_contribution,
max_contribution,
shrink=0.5):
c = (255, 160, 0) if contribution < 0 else (0, 0, 255)
s = (
(self.font_size * shrink) + 1.5 * self.font_size
* (np.absolute(contribution) - min_contribution)
/ float(max_contribution - min_contribution)
)
return s, c
def legend_entry(self, fimps, fimp, align):
size, color = self.get_font_size_and_color(fimp.contribution,
fimps.magnitude.min(),
fimps.magnitude.max())
self.pdf.set_text_color(*color)
self.pdf.set_font_size(size)
self.pdf.cell(w=self.legend_offset,
h=self.line_height,
txt=fimp.feature,
border=0, ln=0,
align=align, fill=False, link='')
def legend_label(self, text, align):
self.pdf.cell(w=self.legend_offset_2,
h=self.line_height,
txt=text, border=0, ln=0,
align=align, fill=False, link='')
def write_legend(self, case_id):
self.pdf.cell(w=0,
h=self.line_height,
txt='Feature contribution legend',
border=0, ln=0,
align='C', fill=False, link='')
self.pdf.ln()
fimps = get_ti_feature_contributions_for_instance_i(case_id,
self.contributions,
self.clf).sort_values(by='magnitude',
ascending=False)
fimps = fimps.head(self.top_n) if self.top_n is not None else fimps
fimp = fimps.loc[fimps.contribution.idxmin()]
self.legend_entry(fimps, fimp, align='L')
fimp = fimps.loc[fimps.contribution < 0]
fimp = fimp.loc[fimp.contribution.idxmax()]
self.legend_entry(fimps, fimp, align='R')
fimp = fimps.loc[fimps.contribution > 0]
fimp = fimp.loc[fimp.contribution.idxmin()]
self.legend_entry(fimps, fimp, align='L')
fimp = fimps.loc[fimps.contribution.idxmax()]
self.legend_entry(fimps, fimp, align='R')
self.pdf.ln()
self.pdf.set_text_color(0, 0, 0)
self.pdf.set_font('Arial', '', self.font_size * .6)
self.legend_label('Largest negative contribution', 'L')
self.legend_label('Smallest contributions', 'C')
self.legend_label('Largest positive contribution', 'R')
self.pdf.set_font('Arial', '', self.font_size)
def print_section_header(self, section):
self.pdf.ln()
self.pdf.set_line_width(0.5)
self.pdf.line(10, self.pdf.get_y(), 210 - 10, self.pdf.get_y())
self.pdf.set_line_width(0.2)
self.pdf.set_font('Arial', 'B', self.font_size)
self.pdf.write(self.line_height, self.section_headers[section] + '\n')
self.pdf.set_font('')
# REFACTOR!
def print_paragraph(self, text, i,
base_color=(128, 128, 128)):
fimps = get_ti_feature_contributions_for_instance_i(i,
self.contributions,
self.clf)
fimps.sort_values(by='magnitude', inplace=True, ascending=False)
fimps = fimps.head(self.top_n)
old_word = ''
old_tr_word = ''
old_bigram = ''
old_bigram_contribution = None
old_color = base_color
old_size = self.font_size
words = text.split(' ')
words.append(' .')
for word in words:
tr_word = self.transform_text(word)
feat_tr = old_tr_word + ' ' + tr_word
contribution_bi = (fimps.loc[fimps.feature == feat_tr]
.iloc[0].contribution
if feat_tr in list(fimps.feature)
else None)
magnitude_bi = np.absolute(contribution_bi) if contribution_bi is not None else 0
feat_tr = old_tr_word
contribution_uni = (fimps.loc[fimps.feature == feat_tr]
.iloc[0].contribution
if feat_tr in list(fimps.feature)
else None)
magnitude_uni = np.absolute(contribution_uni) if contribution_uni is not None else 0
if contribution_bi and magnitude_bi > magnitude_uni:
# print('bigram: ', old_tr_word)
feat_tr = old_tr_word + ' ' + tr_word
feat = old_word + ' ' + word
contribution = fimps.loc[fimps.feature == feat_tr].iloc[0].contribution
size, color = self.get_font_size_and_color(contribution,
fimps.magnitude.min(),
fimps.magnitude.max())
self.pdf.set_text_color(*color)
self.pdf.set_font_size(size)
feat = feat.encode('latin-1', 'replace').decode('latin-1')
self.pdf.write(self.line_height, feat + ' ')
old_word = ''
old_tr_word = ''
old_color = base_color
old_size = self.font_size
elif contribution_uni and magnitude_uni > magnitude_bi:
# print('unigram: ', old_tr_word)
feat_tr = old_tr_word
feat = old_word
contribution = fimps.loc[fimps.feature == feat_tr].iloc[0].contribution
size, color = self.get_font_size_and_color(contribution,
fimps.magnitude.min(),
fimps.magnitude.max())
self.pdf.set_text_color(*color)
self.pdf.set_font_size(size)
feat = feat.encode('latin-1', 'replace').decode('latin-1')
self.pdf.write(self.line_height, feat + ' ')
old_word = word
old_tr_word = tr_word
old_color = base_color
old_size = self.font_size
else:
self.pdf.set_text_color(*old_color)
self.pdf.set_font_size(old_size)
w = old_word.encode('latin-1', 'replace').decode('latin-1')
self.pdf.write(self.line_height, w + ' ')
old_word = word
old_tr_word = tr_word
old_color = base_color
old_size = self.font_size
self.pdf.ln()
def transform_text(self, text):
# Remove all the special characters
document = re.sub(r'\W', ' ', str(text))
# remove all single characters
document = re.sub(r'\s+[a-zA-Z]\s+', ' ', document)
# document = re.sub(r'\s+[a-zA-Z]\s+', ' ', str(X[sen]))
# Remove single characters from the start
document = re.sub(r'\^[a-zA-Z]\s+', ' ', document)
# Substituting multiple spaces with single space
document = re.sub(r'\s+', ' ', document, flags=re.I)
# Removing prefixed 'b'
document = re.sub(r'^b\s+', '', document)
# Converting to Lowercase
document = document.lower()
# Lemmatization
document = document.split()
document = [self.stemmer.lemmatize(word) for word in document]
document = ' '.join(document)
return document