minor updates
sangkeun00 committed Nov 26, 2023
1 parent 9ab4cd9 commit d9e8695
Showing 3 changed files with 36 additions and 5 deletions.
10 changes: 6 additions & 4 deletions analog/analysis/influence_function.py
@@ -10,7 +10,7 @@ def parse_config(self):
         return
 
     @torch.no_grad()
-    def precondition(self, src, damping=0.0):
+    def precondition(self, src, damping=None):
         preconditioned = {}
         (
             hessian_eigval,
@@ -32,6 +32,8 @@ def precondition(self, src, damping=0.0):
                 if is_ekfac
                 else torch.outer(module_eigval["backward"], module_eigval["forward"])
             )
+            if damping is None:
+                damping = 0.1 * torch.mean(scale)
             prec_rotated_grad = rotated_grad / (scale + damping)
             preconditioned[module_name] = einsum(
                 module_eigvec["backward"],
@@ -42,7 +44,7 @@ def precondition(self, src, damping=0.0):
         return preconditioned
 
     @torch.no_grad()
-    def compute_influence(self, src, tgt, preconditioned=False, damping=0.0):
+    def compute_influence(self, src, tgt, preconditioned=False, damping=None):
         if not preconditioned:
             src = self.precondition(src, damping)
 
@@ -58,7 +60,7 @@ def compute_influence(self, src, tgt, preconditioned=False, damping=0.0):
             total_influence += module_influence.squeeze()
         return total_influence
 
-    def compute_self_influence(self, src, damping=0.0):
+    def compute_self_influence(self, src, damping=None):
         src_pc = self.precondition(src, damping)
         total_influence = 0.0
         for module_name in src_pc.keys():
@@ -67,7 +69,7 @@ def compute_self_influence(self, src, damping=0.0):
             total_influence += module_influence.squeeze()
         return total_influence
 
-    def compute_influence_all(self, src, loader, damping=0.0):
+    def compute_influence_all(self, src, loader, damping=None):
         if_scores = []
         src = self.precondition(src, damping)
         for tgt_ids, tgt in loader:
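Note on the damping change in this file: with the new default of damping=None, precondition falls back to 10% of the mean eigenvalue scale instead of applying no damping at all. A minimal standalone sketch of that fallback, using made-up tensors rather than the logged (e)K-FAC state:

import torch

# Hypothetical per-module eigenvalue scale and rotated gradient; in the real
# code these come from the Hessian eigendecomposition and the logged gradients.
scale = torch.tensor([4.0, 1.0, 0.25, 0.01])
rotated_grad = torch.ones_like(scale)

damping = None
if damping is None:
    # Adaptive default from the diff: 10% of the mean eigenvalue scale.
    damping = 0.1 * torch.mean(scale)

prec_rotated_grad = rotated_grad / (scale + damping)
print(damping)            # tensor(0.1315)
print(prec_rotated_grad)  # directions with small eigenvalues are affected most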
2 changes: 1 addition & 1 deletion examples/bert_influence/compute_influence.py
@@ -38,7 +38,7 @@
 # Hessian logging
 analog.watch(model)
 analog_kwargs = {"log": [], "hessian": True, "save": False}
-id_gen = DataIDGenerator()
+id_gen = DataIDGenerator(mode="index")
 for epoch in range(2):
     for batch in tqdm(eval_train_loader, desc="Hessian logging"):
         data_id = id_gen(batch["input_ids"])
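The change above passes mode="index" to DataIDGenerator instead of relying on its default. The generator itself is not shown in this diff; the sketch below is only a hypothetical illustration of an index-based ID scheme, not the library's implementation:

# Hypothetical illustration only; the real DataIDGenerator may behave differently.
class IndexIDGenerator:
    """Assign consecutive integer IDs in arrival order, ignoring batch contents."""

    def __init__(self):
        self.count = 0

    def __call__(self, batch):
        ids = list(range(self.count, self.count + len(batch)))
        self.count += len(batch)
        return ids

id_gen = IndexIDGenerator()
print(id_gen([[101, 2023], [101, 2045]]))  # [0, 1]
print(id_gen([[101, 999]]))                # [2]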
29 changes: 29 additions & 0 deletions examples/bert_influence/qualitative_analysis.py
@@ -0,0 +1,29 @@
+import torch
+
+from utils import get_loaders, set_seed
+
+set_seed(0)
+
+# data
+_, eval_train_loader, test_loader = get_loaders(
+    data_name="sst2",
+    valid_indices=list(range(32)),
+)
+
+# score
+score_path = "if_analog.pt"
+scores = torch.load(score_path, map_location="cpu")
+print(scores.shape)
+
+for i in range(16):
+    print("=" * 80)
+    print(f"{i}th data point")
+    print(f"Sequence: {test_loader.dataset[i]['sentence']}")
+    print(f"Label: {test_loader.dataset[i]['label']}")
+
+    print("Most influential data point")
+    rank = torch.argsort(scores[i], descending=True)
+    for j in range(3):
+        print(f"Rank {j} (score = {scores[i][rank[j]]})")
+        print(f"Sentence: {eval_train_loader.dataset[int(rank[j])]['sentence']}")
+        print(f"Label: {eval_train_loader.dataset[int(rank[j])]['label']}")
