integrated_gradient.py
import numpy as np
from tqdm import tqdm

from saliency_interpreter import SaliencyInterpreter


class IntegratedGradient(SaliencyInterpreter):
    """
    Interprets the prediction using Integrated Gradients (https://arxiv.org/abs/1703.01365)
    Registered as a `SaliencyInterpreter` with name "integrated-gradient".
    """
    def __init__(self,
                 model,
                 criterion,
                 tokenizer,
                 num_steps=20,
                 show_progress=True,
                 **kwargs):
        super().__init__(model, criterion, tokenizer, show_progress, **kwargs)
        # Hyperparameters
        self.num_steps = num_steps
    def saliency_interpret(self, test_dataloader):
        instances_with_grads = []
        iterator = tqdm(test_dataloader) if self.show_progress else test_dataloader
        for batch in iterator:
            # We store the batch outputs (gradients, probabilities, tokens) here;
            # since each of them is used in different places, for convenience we
            # keep them as an attribute:
            self.batch_output = []
            self._integrate_gradients(batch)
            batch_output = self.update_output()
            instances_with_grads.extend(batch_output)

        return instances_with_grads
    def _register_forward_hook(self, alpha, embeddings_list):
        """
        Register a forward hook on the embedding layer which scales the embeddings by alpha. Used
        for one term in the Integrated Gradients sum.
        We store the embedding output into the embeddings_list when alpha is zero. This is used
        later to element-wise multiply the input by the averaged gradients.
        """
        def forward_hook(module, inputs, output):
            # Save the input for later use. Only do so on first call.
            if alpha == 0:
                embeddings_list.append(output.squeeze(0).clone().detach())
            # Scale the embedding by alpha
            output.mul_(alpha)

        embedding_layer = self.get_embeddings_layer()
        handle = embedding_layer.register_forward_hook(forward_hook)
        return handle
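    # `_integrate_gradients` approximates the path integral above with a left
    # Riemann sum over `num_steps` points:
    #
    #     IG_i(x) ~= x_i * (1 / m) * sum_{k=0..m-1} dF((k / m) * x) / dx_i,   m = num_steps
    #
    # i.e. it averages the gradients obtained at the alpha-scaled inputs and then
    # multiplies element-wise by the unscaled input embedding saved at alpha == 0.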
    def _integrate_gradients(self, batch):
        ig_grads = None
        # List of Embedding inputs
        embeddings_list = []
        # Exclude the endpoint because we do a left point integral approximation
        for alpha in np.linspace(0, 1.0, num=self.num_steps, endpoint=False):
            # Hook for modifying embedding value
            handle = self._register_forward_hook(alpha, embeddings_list)
            grads = self._get_gradients(batch)
            handle.remove()
            # Running sum of gradients
            if ig_grads is None:
                ig_grads = grads
            else:
                ig_grads = ig_grads + grads
        # Average of each gradient term
        ig_grads /= self.num_steps
        # Gradients come back in the reverse order that they were sent into the network
        embeddings_list.reverse()
        # Element-wise multiply average gradient by the input
        ig_grads *= embeddings_list[0]
        self.batch_output.append(ig_grads)
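# A minimal usage sketch (illustrative only): the model, tokenizer, criterion
# and dataloader below are assumptions, not part of this file; whether a given
# model works depends on how the base class implements `get_embeddings_layer`,
# `_get_gradients` and `update_output`.
#
#     import torch
#     from transformers import AutoModelForSequenceClassification, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
#     model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
#     criterion = torch.nn.CrossEntropyLoss()
#
#     interpreter = IntegratedGradient(model, criterion, tokenizer, num_steps=20)
#     # `test_dataloader` is any iterable of batches accepted by the base class.
#     instances_with_grads = interpreter.saliency_interpret(test_dataloader)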