forked from koren-v/Interpret
smooth_gradient.py
import torch
from tqdm import tqdm

from saliency_interpreter import SaliencyInterpreter


class SmoothGradient(SaliencyInterpreter):
    """
    Interprets the prediction using SmoothGrad (https://arxiv.org/abs/1706.03825)
    Registered as a `SaliencyInterpreter` with name "smooth-gradient".
    """
    def __init__(self,
                 model,
                 criterion,
                 tokenizer,
                 stdev=0.01,
                 num_samples=20,
                 show_progress=True,
                 **kwargs):
        super().__init__(model, criterion, tokenizer, show_progress, **kwargs)
        # Hyperparameters
        self.stdev = stdev
        self.num_samples = num_samples
    def saliency_interpret(self, test_dataloader):

        instances_with_grads = []
        iterator = tqdm(test_dataloader) if self.show_progress else test_dataloader

        for batch in iterator:
            # Store the batch outputs (gradients, probabilities, tokens) on an
            # attribute, since each of them is consumed in a different place.
            self.batch_output = []
            self._smooth_grads(batch)
            batch_output = self.update_output()
            instances_with_grads.extend(batch_output)

        return instances_with_grads
    def _register_forward_hook(self, stdev: float):
        """
        Register a forward hook on the embedding layer which adds random noise
        to every embedding. Used for one term in the SmoothGrad sum.
        """
        def forward_hook(module, inputs, output):
            # Random noise = N(0, stdev * (max - min))
            scale = output.detach().max() - output.detach().min()
            noise = torch.randn(output.shape).to(output.device) * stdev * scale

            # Add the random noise
            output.add_(noise)

        # Register the hook
        embedding_layer = self.model.get_input_embeddings()
        handle = embedding_layer.register_forward_hook(forward_hook)

        return handle
    def _smooth_grads(self, batch):
        total_gradients = None
        for _ in range(self.num_samples):
            handle = self._register_forward_hook(self.stdev)
            grads = self._get_gradients(batch)
            handle.remove()

            # Sum gradients
            if total_gradients is None:
                total_gradients = grads
            else:
                total_gradients = total_gradients + grads

        # Average the summed gradients over the noisy samples
        total_gradients /= self.num_samples

        self.batch_output.append(total_gradients)
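

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal driver showing how SmoothGradient might be invoked. The concrete
# model, tokenizer, loss, and the batch format expected by SaliencyInterpreter
# / self._get_gradients are assumptions here: this sketch assumes a Hugging
# Face sequence-classification model and a DataLoader yielding
# (input_ids, attention_mask, labels) tuples, which may differ from the
# repository's actual setup.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    # Hypothetical model/checkpoint choice for illustration only
    model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    criterion = torch.nn.CrossEntropyLoss()

    # Build a tiny two-example dataloader
    encodings = tokenizer(["a great movie", "a dull movie"],
                          padding=True, return_tensors="pt")
    dataset = TensorDataset(encodings["input_ids"],
                            encodings["attention_mask"],
                            torch.tensor([1, 0]))
    loader = DataLoader(dataset, batch_size=2)

    # Average gradients over 20 noisy copies of the embeddings
    interpreter = SmoothGradient(model, criterion, tokenizer,
                                 stdev=0.01, num_samples=20)
    instances_with_grads = interpreter.saliency_interpret(loader)
    print(instances_with_grads[0])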