From d9e8695ec619e635f0b059e2ee5f3861d2bdfc19 Mon Sep 17 00:00:00 2001
From: sangkeun00
Date: Sun, 26 Nov 2023 17:05:29 -0500
Subject: [PATCH] minor updates

---
 analog/analysis/influence_function.py        | 10 ++++---
 examples/bert_influence/compute_influence.py |  2 +-
 .../bert_influence/qualitative_analysis.py   | 29 +++++++++++++++++++
 3 files changed, 36 insertions(+), 5 deletions(-)
 create mode 100644 examples/bert_influence/qualitative_analysis.py

diff --git a/analog/analysis/influence_function.py b/analog/analysis/influence_function.py
index 5dbe18aa..faa6ea50 100644
--- a/analog/analysis/influence_function.py
+++ b/analog/analysis/influence_function.py
@@ -10,7 +10,7 @@ def parse_config(self):
         return
 
     @torch.no_grad()
-    def precondition(self, src, damping=0.0):
+    def precondition(self, src, damping=None):
         preconditioned = {}
         (
             hessian_eigval,
@@ -32,6 +32,8 @@ def precondition(self, src, damping=0.0):
                 if is_ekfac
                 else torch.outer(module_eigval["backward"], module_eigval["forward"])
             )
+            if damping is None:
+                damping = 0.1 * torch.mean(scale)
             prec_rotated_grad = rotated_grad / (scale + damping)
             preconditioned[module_name] = einsum(
                 module_eigvec["backward"],
@@ -42,7 +44,7 @@ def precondition(self, src, damping=0.0):
         return preconditioned
 
     @torch.no_grad()
-    def compute_influence(self, src, tgt, preconditioned=False, damping=0.0):
+    def compute_influence(self, src, tgt, preconditioned=False, damping=None):
         if not preconditioned:
             src = self.precondition(src, damping)
 
@@ -58,7 +60,7 @@ def compute_influence(self, src, tgt, preconditioned=False, damping=0.0):
             total_influence += module_influence.squeeze()
         return total_influence
 
-    def compute_self_influence(self, src, damping=0.0):
+    def compute_self_influence(self, src, damping=None):
         src_pc = self.precondition(src, damping)
         total_influence = 0.0
         for module_name in src_pc.keys():
@@ -67,7 +69,7 @@ def compute_self_influence(self, src, damping=0.0):
             total_influence += module_influence.squeeze()
         return total_influence
 
-    def compute_influence_all(self, src, loader, damping=0.0):
+    def compute_influence_all(self, src, loader, damping=None):
         if_scores = []
         src = self.precondition(src, damping)
         for tgt_ids, tgt in loader:
diff --git a/examples/bert_influence/compute_influence.py b/examples/bert_influence/compute_influence.py
index 40607f97..fa59b9b0 100644
--- a/examples/bert_influence/compute_influence.py
+++ b/examples/bert_influence/compute_influence.py
@@ -38,7 +38,7 @@
 # Hessian logging
 analog.watch(model)
 analog_kwargs = {"log": [], "hessian": True, "save": False}
-id_gen = DataIDGenerator()
+id_gen = DataIDGenerator(mode="index")
 for epoch in range(2):
     for batch in tqdm(eval_train_loader, desc="Hessian logging"):
         data_id = id_gen(batch["input_ids"])
diff --git a/examples/bert_influence/qualitative_analysis.py b/examples/bert_influence/qualitative_analysis.py
new file mode 100644
index 00000000..854a413a
--- /dev/null
+++ b/examples/bert_influence/qualitative_analysis.py
@@ -0,0 +1,29 @@
+import torch
+
+from utils import get_loaders, set_seed
+
+set_seed(0)
+
+# data
+_, eval_train_loader, test_loader = get_loaders(
+    data_name="sst2",
+    valid_indices=list(range(32)),
+)
+
+# score
+score_path = "if_analog.pt"
+scores = torch.load(score_path, map_location="cpu")
+print(scores.shape)
+
+for i in range(16):
+    print("=" * 80)
+    print(f"Test data point {i}")
+    print(f"Sentence: {test_loader.dataset[i]['sentence']}")
+    print(f"Label: {test_loader.dataset[i]['label']}")
+
+    print("Most influential training data points")
+    rank = torch.argsort(scores[i], descending=True)
+    for j in range(3):
+        print(f"Rank {j} (score = {scores[i][rank[j]]})")
+        print(f"Sentence: {eval_train_loader.dataset[int(rank[j])]['sentence']}")
+        print(f"Label: {eval_train_loader.dataset[int(rank[j])]['label']}")
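
Note on the new damping default: with damping=None, precondition() now falls back to 0.1 times the mean of the per-module eigenvalue scale instead of a fixed 0.0. The snippet below is a minimal standalone sketch of that heuristic for a single toy module; the tensor names (eigval_fwd, eigval_bwd, rotated_grad) and shapes are illustrative, not part of the analog API, and it only mirrors the arithmetic in the diff.

    import torch

    # Toy KFAC eigenvalues for one module (shapes are arbitrary for illustration)
    eigval_fwd = torch.rand(4)        # forward-factor (activation) eigenvalues
    eigval_bwd = torch.rand(3)        # backward-factor (gradient) eigenvalues
    rotated_grad = torch.randn(3, 4)  # gradient rotated into the eigenbasis

    # Kronecker-factored eigenvalue scale, as in precondition()
    scale = torch.outer(eigval_bwd, eigval_fwd)

    # New default: damping falls back to 0.1 * mean(scale) when not supplied
    damping = None
    if damping is None:
        damping = 0.1 * torch.mean(scale)

    prec_rotated_grad = rotated_grad / (scale + damping)
    print(damping.item(), prec_rotated_grad.shape)

One detail worth double-checking in review: because the fallback is assigned inside the per-module loop of precondition(), damping stops being None after the first module, so later modules appear to reuse the first module's 0.1 * mean(scale) rather than computing their own. Passing an explicit damping value preserves the previous fixed-damping behaviour.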