Skip to content

Commit

Permalink
Fix typo
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Mar 22, 2024
1 parent 6b90e6e commit c7a8cdf
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/wikitext/evaluate_lds.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def main():
mask = torch.from_numpy(results["mask"]).float()
mask = ((mask + 1) % 2).to(dtype=torch.float64).t()

# The path might need to get fixed.
# You might need to change the path.
scores = Analyzer.load_file("scores_pairwise/ekfac_pairwise.safetensors")["all_modules"].to(dtype=torch.float64)
preds = (scores @ mask).t().numpy()

Expand Down
5 changes: 3 additions & 2 deletions examples/wikitext/run_counterfactual.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def main():
logging.basicConfig(level=logging.INFO)

train_dataset = get_wikitext_dataset(split="train")
# You might need to change the path.
scores = Analyzer.load_file("analyses/wikitext/scores_ekfac_pairwise/pairwise_scores.safetensors")["all_modules"][
:50
].sum(dim=0)
Expand All @@ -26,7 +27,7 @@ def get_topk_keep_indices(current_score: torch.Tensor, topk: int = 1) -> List[in
remove_indices = [tensor.item() for tensor in remove_indices]
return list(set(list(range(len(train_dataset)))) - set(remove_indices))

eval_train_dataset = get_wikitext_dataset(split="eval_train", indices=list(range(50)))
eval_train_dataset = get_wikitext_dataset(split="valid", indices=list(range(50)))

def train_and_evaluate(indices):
train_dataset = get_wikitext_dataset(split="train", indices=indices)
Expand All @@ -37,7 +38,7 @@ def train_and_evaluate(indices):
learning_rate=3e-05,
weight_decay=0.01,
)
return evaluate_model(model, eval_train_dataset, 1)
return evaluate_model(model, eval_train_dataset, batch_size=16)

num_iter = 1
topk_lst = [0, 50, 100, 150, 200]
Expand Down

0 comments on commit c7a8cdf

Please sign in to comment.