# saliency_interpreter.py (forked from koren-v/Interpret)
import torch
from torch.nn.functional import softmax

import matplotlib
import matplotlib.pyplot as plt


class SaliencyInterpreter:
    """
    Base class for gradient-based saliency interpretation: it captures the
    gradients flowing into the input embeddings for a batch and renders
    token-level saliency as colored HTML.
    """
    def __init__(self,
                 model,
                 criterion,
                 tokenizer,
                 show_progress=True,
                 **kwargs):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.model.eval()
        self.criterion = criterion
        self.tokenizer = tokenizer
        self.show_progress = show_progress
        self.kwargs = kwargs
        # to save outputs in saliency_interpret
        self.batch_output = None

    def _get_gradients(self, batch):
        # set requires_grad to True for all parameters, but save the original
        # values to restore them later
        original_param_name_to_requires_grad_dict = {}
        for param_name, param in self.model.named_parameters():
            original_param_name_to_requires_grad_dict[param_name] = param.requires_grad
            param.requires_grad = True

        embedding_gradients = []
        hooks = self._register_embedding_gradient_hooks(embedding_gradients)

        loss = self.forward_step(batch)
        self.model.zero_grad()
        loss.backward()

        for hook in hooks:
            hook.remove()

        # restore the original requires_grad values of the parameters
        for param_name, param in self.model.named_parameters():
            param.requires_grad = original_param_name_to_requires_grad_dict[param_name]

        return embedding_gradients[0]

    def _register_embedding_gradient_hooks(self, embedding_gradients):
        """
        Registers a backward hook on the model's input embedding layer.
        Used to save the gradients of the embeddings for use in _get_gradients().
        When there are multiple inputs (e.g., a passage and question), the hook
        will be called multiple times. We append all the embedding gradients
        to a list.
        """
        def hook_layers(module, grad_in, grad_out):
            # grad_out[0] is the gradient w.r.t. the output of the embedding layer
            embedding_gradients.append(grad_out[0])

        backward_hooks = []
        embedding_layer = self.model.get_input_embeddings()
        # note: newer PyTorch versions prefer register_full_backward_hook
        backward_hooks.append(embedding_layer.register_backward_hook(hook_layers))
        return backward_hooks

    def colorize(self, instance, skip_special_tokens=False):
        special_tokens = self.tokenizer.eos_token, self.tokenizer.bos_token
        word_cmap = matplotlib.cm.Blues
        prob_cmap = matplotlib.cm.Greens
        template = '<span class="barcode"; style="color: black; background-color: {}">{}</span>'
        colored_string = ''
        for word, color in zip(instance['tokens'], instance['grad']):
            if word in special_tokens and skip_special_tokens:
                continue
            # handle wordpieces
            word = word.replace("##", "") if "##" in word else ' ' + word
            color = matplotlib.colors.rgb2hex(word_cmap(color)[:3])
            colored_string += template.format(color, word)
        colored_string += template.format(0, " Label: {} |".format(instance['label']))
        prob = instance['prob']
        color = matplotlib.colors.rgb2hex(prob_cmap(prob)[:3])
        colored_string += template.format(color, "{:.2f}%".format(instance['prob'] * 100)) + '|'
        return colored_string

    def forward_step(self, batch):
        """
        If your model receives its inputs in another way, or computes the loss
        differently than in this example, simply override this method. It should
        return the batch loss (see the commented sketch below and the usage
        example at the end of this file).
        :param batch: batch returned by the dataloader
        :return: torch.Tensor: batch loss
        """
        input_ids = batch.get('input_ids').to(self.device)
        attention_mask = batch.get("attention_mask").to(self.device)

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # gradients are taken w.r.t. the predicted label, so no ground truth is needed
        label = torch.argmax(outputs, dim=1)

        batch_losses = self.criterion(outputs, label)
        loss = torch.mean(batch_losses)

        self.batch_output = [input_ids, outputs]
        return loss

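    # A hypothetical override for models whose forward returns an output object
    # rather than raw logits (e.g. Hugging Face sequence classifiers) could look
    # like this (sketch only, the names are assumptions):
    #
    #     def forward_step(self, batch):
    #         input_ids = batch["input_ids"].to(self.device)
    #         attention_mask = batch["attention_mask"].to(self.device)
    #         logits = self.model(input_ids=input_ids,
    #                             attention_mask=attention_mask).logits
    #         label = torch.argmax(logits, dim=1)
    #         loss = self.criterion(logits, label).mean()
    #         self.batch_output = [input_ids, logits]
    #         return loss
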
    def update_output(self):
        """
        You can also override this method if you want to change the format
        of the outputs (e.g. to store just the gradients).
        :return: batch_output
        """
        input_ids, outputs, grads = self.batch_output

        probs = softmax(outputs, dim=-1)
        probs, labels = torch.max(probs, dim=-1)

        tokens = [
            self.tokenizer.convert_ids_to_tokens(input_ids_)
            for input_ids_ in input_ids
        ]

        # collapse the embedding dimension: (batch, seq_len, emb_dim) -> (batch, seq_len)
        embedding_grads = grads.sum(dim=2)
        # L1 norm for each sequence
        norms = torch.norm(embedding_grads, dim=1, p=1)
        # normalizing
        for i, norm in enumerate(norms):
            embedding_grads[i] = torch.abs(embedding_grads[i]) / norm

        batch_output = []
        iterator = zip(tokens, probs, embedding_grads, labels)
        for example_tokens, example_prob, example_grad, example_label in iterator:
            example_dict = dict()
            # sequences are padded within a batch, so drop the padding tokens
            example_tokens = [t for t in example_tokens if t != self.tokenizer.pad_token]
            example_dict['tokens'] = example_tokens
            example_dict['grad'] = example_grad.cpu().tolist()[:len(example_tokens)]
            example_dict['label'] = example_label.item()
            example_dict['prob'] = example_prob.item()
            batch_output.append(example_dict)

        return batch_output
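

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original interface). It assumes a
# Hugging Face sequence-classification checkpoint; the wrapper below only
# exists so that the unmodified forward_step above sees raw logits —
# alternatively one could override forward_step as sketched in the class.
# All names in this block are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch import nn
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    class LogitsOnly(nn.Module):
        """Wraps a Hugging Face classifier so its forward returns raw logits."""

        def __init__(self, hf_model):
            super().__init__()
            self.hf_model = hf_model

        def get_input_embeddings(self):
            return self.hf_model.get_input_embeddings()

        def forward(self, input_ids=None, attention_mask=None):
            return self.hf_model(input_ids=input_ids,
                                 attention_mask=attention_mask).logits

    name = "distilbert-base-uncased-finetuned-sst-2-english"
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = LogitsOnly(AutoModelForSequenceClassification.from_pretrained(name))

    interpreter = SaliencyInterpreter(
        model,
        criterion=nn.CrossEntropyLoss(reduction="none"),
        tokenizer=tokenizer,
    )

    batch = tokenizer(["A surprisingly touching little film."],
                      return_tensors="pt", padding=True)
    # vanilla-gradient saliency: backprop once, then reuse update_output/colorize;
    # the repo's subclasses refine this step (e.g. smoothed or integrated gradients)
    grads = interpreter._get_gradients(batch)
    interpreter.batch_output.append(grads)
    html = "".join(interpreter.colorize(inst) for inst in interpreter.update_output())
    print(html)  # render with IPython.display.HTML(html) in a notebook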