From ecbdfe00cb64718425e1280ba10198a4bcf6e5f3 Mon Sep 17 00:00:00 2001 From: Anthony Duong Date: Mon, 3 Jun 2024 22:35:56 -0700 Subject: [PATCH 1/4] fixes demos pip install from unfound repos --- demos/Activation_Patching_in_TL_Demo.ipynb | 2 +- demos/BERT.ipynb | 2 +- demos/Grokking_Demo.ipynb | 4 ++-- demos/Head_Detector_Demo.ipynb | 4 ++-- demos/Othello_GPT.ipynb | 2 +- demos/Qwen.ipynb | 2 +- demos/SVD_Interpreter_Demo.ipynb | 4 ++-- demos/Santa_Coder.ipynb | 2 +- demos/Tracr_to_Transformer_Lens_Demo.ipynb | 2 +- 9 files changed, 12 insertions(+), 12 deletions(-) diff --git a/demos/Activation_Patching_in_TL_Demo.ipynb b/demos/Activation_Patching_in_TL_Demo.ipynb index e94768020..3be728cb1 100644 --- a/demos/Activation_Patching_in_TL_Demo.ipynb +++ b/demos/Activation_Patching_in_TL_Demo.ipynb @@ -60,7 +60,7 @@ " print(\"Running as a Colab notebook\")\n", " %pip install git+https://github.com/TransformerLensOrg/TransformerLens.git\n", " # Install my janky personal plotting utils\n", - " %pip install git+https://github.com/TransformerLensOrg/neel-plotly.git\n", + " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", "except:\n", " IN_COLAB = False\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", diff --git a/demos/BERT.ipynb b/demos/BERT.ipynb index 5c2c96c11..d086fed8c 100644 --- a/demos/BERT.ipynb +++ b/demos/BERT.ipynb @@ -72,7 +72,7 @@ " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", " # # Install another version of node that makes PySvelte work way faster\n", " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", - " # %pip install git+https://github.com/TransformerLensOrg/PySvelte.git\n", + " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", "except:\n", " IN_COLAB = False\n", "\n", diff --git a/demos/Grokking_Demo.ipynb b/demos/Grokking_Demo.ipynb index b0caedf26..26049675e 100644 --- a/demos/Grokking_Demo.ipynb +++ b/demos/Grokking_Demo.ipynb @@ -65,7 +65,7 @@ " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", " # # Install another version of node that makes PySvelte work way faster\n", " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", - " # %pip install git+https://github.com/TransformerLensOrg/PySvelte.git\n", + " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", "except:\n", " IN_COLAB = False\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", @@ -932,7 +932,7 @@ } ], "source": [ - "%pip install git+https://github.com/TransformerLensOrg/neel-plotly.git \n", + "%pip install git+https://github.com/neelnanda-io/neel-plotly.git \n", "from neel_plotly.plot import line\n", "line([train_losses[::100], test_losses[::100]], x=np.arange(0, len(train_losses), 100), xaxis=\"Epoch\", yaxis=\"Loss\", log_y=True, title=\"Training Curve for Modular Addition\", line_labels=['train', 'test'], toggle_x=True, toggle_y=True)" ] diff --git a/demos/Head_Detector_Demo.ipynb b/demos/Head_Detector_Demo.ipynb index 4f8ccf5bc..af55fcaeb 100644 --- a/demos/Head_Detector_Demo.ipynb +++ b/demos/Head_Detector_Demo.ipynb @@ -318,10 +318,10 @@ "if IN_COLAB or IN_GITHUB:\n", " %pip install git+https://github.com/TransformerLensOrg/TransformerLens.git\n", " # Install Neel's personal plotting utils\n", - " %pip install git+https://github.com/TransformerLensOrg/neel-plotly.git\n", + " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", " # Install another version of node that makes PySvelte work way faster\n", " !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", - " %pip install git+https://github.com/TransformerLensOrg/PySvelte.git\n", + " %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", " # Needed for PySvelte to work, v3 came out and broke things...\n", " %pip install typeguard==2.13.3\n", " %pip install typing-extensions" diff --git a/demos/Othello_GPT.ipynb b/demos/Othello_GPT.ipynb index 4f56cb00a..191785ae9 100644 --- a/demos/Othello_GPT.ipynb +++ b/demos/Othello_GPT.ipynb @@ -203,7 +203,7 @@ " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", " # # Install another version of node that makes PySvelte work way faster\n", " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", - " # %pip install git+https://github.com/TransformerLensOrg/PySvelte.git\n", + " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", "except:\n", " IN_COLAB = False\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", diff --git a/demos/Qwen.ipynb b/demos/Qwen.ipynb index bcde85da4..fba5144ae 100644 --- a/demos/Qwen.ipynb +++ b/demos/Qwen.ipynb @@ -117,7 +117,7 @@ " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", " # # Install another version of node that makes PySvelte work way faster\n", " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", - " # %pip install git+https://github.com/TransformerLensOrg/PySvelte.git\n", + " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", "except:\n", " IN_COLAB = False\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", diff --git a/demos/SVD_Interpreter_Demo.ipynb b/demos/SVD_Interpreter_Demo.ipynb index a1669deb6..0f1c38021 100644 --- a/demos/SVD_Interpreter_Demo.ipynb +++ b/demos/SVD_Interpreter_Demo.ipynb @@ -71,10 +71,10 @@ " print(\"Running as a Colab notebook\")\n", " %pip install git+https://github.com/JayBaileyCS/TransformerLens.git # TODO: Change!\n", " # Install Neel's personal plotting utils\n", - " %pip install git+https://github.com/TransformerLensOrg/neel-plotly.git\n", + " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", " # Install another version of node that makes PySvelte work way faster\n", " !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", - " %pip install git+https://github.com/TransformerLensOrg/PySvelte.git\n", + " %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", " # Needed for PySvelte to work, v3 came out and broke things...\n", " %pip install typeguard==2.13.3\n", " %pip install typing-extensions\n", diff --git a/demos/Santa_Coder.ipynb b/demos/Santa_Coder.ipynb index 936d2ffff..0c95abd1d 100644 --- a/demos/Santa_Coder.ipynb +++ b/demos/Santa_Coder.ipynb @@ -40,7 +40,7 @@ " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", " # # Install another version of node that makes PySvelte work way faster\n", " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", - " # %pip install git+https://github.com/TransformerLensOrg/PySvelte.git\n", + " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", "except:\n", " IN_COLAB = False\n", " print(\"Running as a Jupyter notebook - intended for development only!\")\n", diff --git a/demos/Tracr_to_Transformer_Lens_Demo.ipynb b/demos/Tracr_to_Transformer_Lens_Demo.ipynb index b55e57b8a..bb1bb54dd 100644 --- a/demos/Tracr_to_Transformer_Lens_Demo.ipynb +++ b/demos/Tracr_to_Transformer_Lens_Demo.ipynb @@ -65,7 +65,7 @@ " print(\"Running as a Colab notebook\")\n", " %pip install transformer_lens\n", " # Fork of Tracr that's backward compatible with Python 3.8\n", - " %pip install git+https://github.com/TransformerLensOrg/Tracr\n", + " %pip install git+https://github.com/neelnanda-io/Tracr\n", " \n", "except:\n", " IN_COLAB = False\n", From 7c913837a7b6f5dd41576a9c86156f729589203a Mon Sep 17 00:00:00 2001 From: Anthony Duong Date: Mon, 3 Jun 2024 22:38:50 -0700 Subject: [PATCH 2/4] fixes in Attribution_Patching_Demo.ipynb --- demos/Attribution_Patching_Demo.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demos/Attribution_Patching_Demo.ipynb b/demos/Attribution_Patching_Demo.ipynb index de7c7088b..2862fb9c8 100644 --- a/demos/Attribution_Patching_Demo.ipynb +++ b/demos/Attribution_Patching_Demo.ipynb @@ -1 +1 @@ -{"cells":[{"cell_type":"markdown","metadata":{},"source":["\n"," \"Open\n",""]},{"cell_type":"markdown","metadata":{},"source":[" # Attribution Patching Demo\n"," **Read [the accompanying blog post here](https://neelnanda.io/attribution-patching) for more context**\n"," This is an interim research report, giving a whirlwind tour of some unpublished work I did at Anthropic (credit to the then team - Chris Olah, Catherine Olsson, Nelson Elhage and Tristan Hume for help, support, and mentorship!)\n","\n"," The goal of this work is run activation patching at an industrial scale, by using gradient based attribution to approximate the technique - allow an arbitrary number of patches to be made on two forwards and a single backward pass\n","\n"," I have had less time than hoped to flesh out this investigation, but am writing up a rough investigation and comparison to standard activation patching on a few tasks to give a sense of the potential of this approach, and where it works vs falls down."]},{"cell_type":"markdown","metadata":{},"source":[" To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n","\n"," **Tips for reading this Colab:**\n"," * You can run all this code for yourself!\n"," * The graphs are interactive!\n"," * Use the table of contents pane in the sidebar to navigate\n"," * Collapse irrelevant sections with the dropdown arrows\n"," * Search the page using the search in the sidebar, not CTRL+F"]},{"cell_type":"markdown","metadata":{},"source":[" ## Setup (Ignore)"]},{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Running as a Jupyter notebook - intended for development only!\n"]},{"name":"stderr","output_type":"stream","text":["/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_25358/2480103146.py:24: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n"," ipython.magic(\"load_ext autoreload\")\n","/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_25358/2480103146.py:25: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n"," ipython.magic(\"autoreload 2\")\n"]}],"source":["# Janky code to do different setup when run in a Colab notebook vs VSCode\n","import os\n","\n","DEBUG_MODE = False\n","IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n","try:\n"," import google.colab\n","\n"," IN_COLAB = True\n"," print(\"Running as a Colab notebook\")\n","except:\n"," IN_COLAB = False\n"," print(\"Running as a Jupyter notebook - intended for development only!\")\n"," from IPython import get_ipython\n","\n"," ipython = get_ipython()\n"," # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n"," ipython.magic(\"load_ext autoreload\")\n"," ipython.magic(\"autoreload 2\")\n","\n","if IN_COLAB or IN_GITHUB:\n"," %pip install transformer_lens\n"," %pip install torchtyping\n"," # Install my janky personal plotting utils\n"," %pip install git+https://github.com/TransformerLensOrg/neel-plotly.git\n"," # Install another version of node that makes PySvelte work way faster\n"," %pip install circuitsvis\n"," # Needed for PySvelte to work, v3 came out and broke things...\n"," %pip install typeguard==2.13.3"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n","import plotly.io as pio\n","\n","if IN_COLAB or not DEBUG_MODE:\n"," # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.\n"," pio.renderers.default = \"colab\"\n","else:\n"," pio.renderers.default = \"notebook_connected\""]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[{"ename":"ModuleNotFoundError","evalue":"No module named 'torchtyping'","output_type":"error","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)","Cell \u001b[0;32mIn[3], line 15\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mplotly\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mexpress\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpx\u001b[39;00m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DataLoader\n\u001b[0;32m---> 15\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorchtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m TensorType \u001b[38;5;28;01mas\u001b[39;00m TT\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m List, Union, Optional, Callable\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfunctools\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m partial\n","\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torchtyping'"]}],"source":["# Import stuff\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import numpy as np\n","import einops\n","from fancy_einsum import einsum\n","import tqdm.notebook as tqdm\n","import random\n","from pathlib import Path\n","import plotly.express as px\n","from torch.utils.data import DataLoader\n","\n","from torchtyping import TensorType as TT\n","from typing import List, Union, Optional, Callable\n","from functools import partial\n","import copy\n","import itertools\n","import json\n","\n","from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n","import dataclasses\n","import datasets\n","from IPython.display import HTML, Markdown"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[],"source":["import transformer_lens\n","import transformer_lens.utils as utils\n","from transformer_lens.hook_points import (\n"," HookedRootModule,\n"," HookPoint,\n",") # Hooking utilities\n","from transformer_lens import (\n"," HookedTransformer,\n"," HookedTransformerConfig,\n"," FactoredMatrix,\n"," ActivationCache,\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" Plotting helper functions from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":["from neel_plotly import line, imshow, scatter"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["import transformer_lens.patching as patching"]},{"cell_type":"markdown","metadata":{},"source":[" ## IOI Patching Setup\n"," This just copies the relevant set up from Exploratory Analysis Demo, and isn't very important."]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-small into HookedTransformer\n"]}],"source":["model = HookedTransformer.from_pretrained(\"gpt2-small\")\n","model.set_use_attn_result(True)"]},{"cell_type":"code","execution_count":9,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to\n","Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to\n","Answer token indices tensor([[ 5335, 1757],\n"," [ 1757, 5335],\n"," [ 4186, 3700],\n"," [ 3700, 4186],\n"," [ 6035, 15686],\n"," [15686, 6035],\n"," [ 5780, 14235],\n"," [14235, 5780]], device='cuda:0')\n"]}],"source":["prompts = [\n"," \"When John and Mary went to the shops, John gave the bag to\",\n"," \"When John and Mary went to the shops, Mary gave the bag to\",\n"," \"When Tom and James went to the park, James gave the ball to\",\n"," \"When Tom and James went to the park, Tom gave the ball to\",\n"," \"When Dan and Sid went to the shops, Sid gave an apple to\",\n"," \"When Dan and Sid went to the shops, Dan gave an apple to\",\n"," \"After Martin and Amy went to the park, Amy gave a drink to\",\n"," \"After Martin and Amy went to the park, Martin gave a drink to\",\n","]\n","answers = [\n"," (\" Mary\", \" John\"),\n"," (\" John\", \" Mary\"),\n"," (\" Tom\", \" James\"),\n"," (\" James\", \" Tom\"),\n"," (\" Dan\", \" Sid\"),\n"," (\" Sid\", \" Dan\"),\n"," (\" Martin\", \" Amy\"),\n"," (\" Amy\", \" Martin\"),\n","]\n","\n","clean_tokens = model.to_tokens(prompts)\n","# Swap each adjacent pair, with a hacky list comprehension\n","corrupted_tokens = clean_tokens[\n"," [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]\n","]\n","print(\"Clean string 0\", model.to_string(clean_tokens[0]))\n","print(\"Corrupted string 0\", model.to_string(corrupted_tokens[0]))\n","\n","answer_token_indices = torch.tensor(\n"," [\n"," [model.to_single_token(answers[i][j]) for j in range(2)]\n"," for i in range(len(answers))\n"," ],\n"," device=model.cfg.device,\n",")\n","print(\"Answer token indices\", answer_token_indices)"]},{"cell_type":"code","execution_count":10,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 3.5519\n","Corrupted logit diff: -3.5519\n"]}],"source":["def get_logit_diff(logits, answer_token_indices=answer_token_indices):\n"," if len(logits.shape) == 3:\n"," # Get final logits only\n"," logits = logits[:, -1, :]\n"," correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))\n"," incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))\n"," return (correct_logits - incorrect_logits).mean()\n","\n","\n","clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n","corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n","\n","clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()\n","print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n","\n","corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()\n","print(f\"Corrupted logit diff: {corrupted_logit_diff:.4f}\")"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Baseline is 1: 1.0000\n","Corrupted Baseline is 0: 0.0000\n"]}],"source":["CLEAN_BASELINE = clean_logit_diff\n","CORRUPTED_BASELINE = corrupted_logit_diff\n","\n","\n","def ioi_metric(logits, answer_token_indices=answer_token_indices):\n"," return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (\n"," CLEAN_BASELINE - CORRUPTED_BASELINE\n"," )\n","\n","\n","print(f\"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}\")\n","print(f\"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Patching\n"," In the following cells, we define attribution patching and use it in various ways on the model."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["Metric = Callable[[TT[\"batch_and_pos_dims\", \"d_model\"]], float]"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Value: 1.0\n","Clean Activations Cached: 220\n","Clean Gradients Cached: 220\n","Corrupted Value: 0.0\n","Corrupted Activations Cached: 220\n","Corrupted Gradients Cached: 220\n"]}],"source":["filter_not_qkv_input = lambda name: \"_input\" not in name\n","\n","\n","def get_cache_fwd_and_bwd(model, tokens, metric):\n"," model.reset_hooks()\n"," cache = {}\n","\n"," def forward_cache_hook(act, hook):\n"," cache[hook.name] = act.detach()\n","\n"," model.add_hook(filter_not_qkv_input, forward_cache_hook, \"fwd\")\n","\n"," grad_cache = {}\n","\n"," def backward_cache_hook(act, hook):\n"," grad_cache[hook.name] = act.detach()\n","\n"," model.add_hook(filter_not_qkv_input, backward_cache_hook, \"bwd\")\n","\n"," value = metric(model(tokens))\n"," value.backward()\n"," model.reset_hooks()\n"," return (\n"," value.item(),\n"," ActivationCache(cache, model),\n"," ActivationCache(grad_cache, model),\n"," )\n","\n","\n","clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(\n"," model, clean_tokens, ioi_metric\n",")\n","print(\"Clean Value:\", clean_value)\n","print(\"Clean Activations Cached:\", len(clean_cache))\n","print(\"Clean Gradients Cached:\", len(clean_grad_cache))\n","corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(\n"," model, corrupted_tokens, ioi_metric\n",")\n","print(\"Corrupted Value:\", corrupted_value)\n","print(\"Corrupted Activations Cached:\", len(corrupted_cache))\n","print(\"Corrupted Gradients Cached:\", len(corrupted_grad_cache))"]},{"cell_type":"markdown","metadata":{},"source":[" ### Attention Attribution\n"," The easiest thing to start with is to not even engage with the corrupted tokens/patching, but to look at the attribution of the attention patterns - that is, the linear approximation to what happens if you set each element of the attention pattern to zero. This, as it turns out, is a good proxy to what is going on with each head!\n"," Note that this is *not* the same as what we will later do with patching. In particular, this does not set up a careful counterfactual! It's a good tool for what's generally going on in this problem, but does not control for eg stuff that systematically boosts John > Mary in general, stuff that says \"I should activate the IOI circuit\", etc. Though using logit diff as our metric *does*\n"," Each element of the batch is independent and the metric is an average logit diff, so we can analyse each batch element independently here. We'll look at the first one, and then at the average across the whole batch (note - 4 prompts have indirect object before subject, 4 prompts have it the other way round, making the average pattern harder to interpret - I plot it over the first sequence of tokens as a mildly misleading reference).\n"," We can compare it to the interpretability in the wild diagram, and basically instantly recover most of the circuit!"]},{"cell_type":"code","execution_count":14,"metadata":{},"outputs":[],"source":["def create_attention_attr(\n"," clean_cache, clean_grad_cache\n",") -> TT[\"batch\", \"layer\", \"head_index\", \"dest\", \"src\"]:\n"," attention_stack = torch.stack(\n"," [clean_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," attention_grad_stack = torch.stack(\n"," [clean_grad_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," attention_attr = attention_grad_stack * attention_stack\n"," attention_attr = einops.rearrange(\n"," attention_attr,\n"," \"layer batch head_index dest src -> batch layer head_index dest src\",\n"," )\n"," return attention_attr\n","\n","\n","attention_attr = create_attention_attr(clean_cache, clean_grad_cache)"]},{"cell_type":"code","execution_count":15,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4']\n","['L0H0+', 'L0H0-', 'L0H1+', 'L0H1-', 'L0H2+']\n","['L0H0Q', 'L0H0K', 'L0H0V', 'L0H1Q', 'L0H1K']\n"]}],"source":["HEAD_NAMES = [\n"," f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)\n","]\n","HEAD_NAMES_SIGNED = [f\"{name}{sign}\" for name in HEAD_NAMES for sign in [\"+\", \"-\"]]\n","HEAD_NAMES_QKV = [\n"," f\"{name}{act_name}\" for name in HEAD_NAMES for act_name in [\"Q\", \"K\", \"V\"]\n","]\n","print(HEAD_NAMES[:5])\n","print(HEAD_NAMES_SIGNED[:5])\n","print(HEAD_NAMES_QKV[:5])"]},{"cell_type":"markdown","metadata":{},"source":[" An extremely janky way to plot the attention attribution patterns. We scale them to be in [-1, 1], split each head into a positive and negative part (so all of it is in [0, 1]), and then plot the top 20 head-halves (a head can appear twice!) by the max value of the attribution pattern."]},{"cell_type":"code","execution_count":16,"metadata":{},"outputs":[{"data":{"text/markdown":["### Attention Attribution for first sequence"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["### Summed Attention Attribution for all sequences"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\n"]}],"source":["def plot_attention_attr(attention_attr, tokens, top_k=20, index=0, title=\"\"):\n"," if len(tokens.shape) == 2:\n"," tokens = tokens[index]\n"," if len(attention_attr.shape) == 5:\n"," attention_attr = attention_attr[index]\n"," attention_attr_pos = attention_attr.clamp(min=-1e-5)\n"," attention_attr_neg = -attention_attr.clamp(max=1e-5)\n"," attention_attr_signed = torch.stack([attention_attr_pos, attention_attr_neg], dim=0)\n"," attention_attr_signed = einops.rearrange(\n"," attention_attr_signed,\n"," \"sign layer head_index dest src -> (layer head_index sign) dest src\",\n"," )\n"," attention_attr_signed = attention_attr_signed / attention_attr_signed.max()\n"," attention_attr_indices = (\n"," attention_attr_signed.max(-1).values.max(-1).values.argsort(descending=True)\n"," )\n"," # print(attention_attr_indices.shape)\n"," # print(attention_attr_indices)\n"," attention_attr_signed = attention_attr_signed[attention_attr_indices, :, :]\n"," head_labels = [HEAD_NAMES_SIGNED[i.item()] for i in attention_attr_indices]\n","\n"," if title:\n"," display(Markdown(\"### \" + title))\n"," display(\n"," pysvelte.AttentionMulti(\n"," tokens=model.to_str_tokens(tokens),\n"," attention=attention_attr_signed.permute(1, 2, 0)[:, :, :top_k],\n"," head_labels=head_labels[:top_k],\n"," )\n"," )\n","\n","\n","plot_attention_attr(\n"," attention_attr,\n"," clean_tokens,\n"," index=0,\n"," title=\"Attention Attribution for first sequence\",\n",")\n","\n","plot_attention_attr(\n"," attention_attr.sum(0),\n"," clean_tokens[0],\n"," title=\"Summed Attention Attribution for all sequences\",\n",")\n","print(\n"," \"Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\"\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Attribution Patching\n"," In the following sections, I will implement various kinds of attribution patching, and then compare them to the activation patching patterns (activation patching code copied from [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo))\n"," ### Residual Stream Patching\n","
Note: We add up across both d_model and batch (Explanation).\n"," We add up along d_model because we're taking the dot product - the derivative *is* the linear map that locally linearly approximates the metric, and so we take the dot product of our change vector with the derivative vector. Equivalent, we look at the effect of changing each coordinate independently, and then combine them by adding it up - it's linear, so this totally works.\n"," We add up across batch because we're taking the average of the metric, so each individual batch element provides `1/batch_size` of the overall effect. Because each batch element is independent of the others and no information moves between activations for different inputs, the batched version is equivalent to doing attribution patching separately for each input, and then averaging - in this second version the metric per input is *not* divided by batch_size because we don't average.
"]},{"cell_type":"code","execution_count":17,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_residual(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"pos\"]:\n"," clean_residual, residual_labels = clean_cache.accumulated_resid(\n"," -1, incl_mid=True, return_labels=True\n"," )\n"," corrupted_residual = corrupted_cache.accumulated_resid(\n"," -1, incl_mid=True, return_labels=False\n"," )\n"," corrupted_grad_residual = corrupted_grad_cache.accumulated_resid(\n"," -1, incl_mid=True, return_labels=False\n"," )\n"," residual_attr = einops.reduce(\n"," corrupted_grad_residual * (clean_residual - corrupted_residual),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\",\n"," )\n"," return residual_attr, residual_labels\n","\n","\n","residual_attr, residual_labels = attr_patch_residual(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," residual_attr,\n"," y=residual_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=\"Residual Attribution Patching\",\n",")\n","\n","# ### Layer Output Patching"]},{"cell_type":"code","execution_count":18,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_layer_out(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"pos\"]:\n"," clean_layer_out, labels = clean_cache.decompose_resid(-1, return_labels=True)\n"," corrupted_layer_out = corrupted_cache.decompose_resid(-1, return_labels=False)\n"," corrupted_grad_layer_out = corrupted_grad_cache.decompose_resid(\n"," -1, return_labels=False\n"," )\n"," layer_out_attr = einops.reduce(\n"," corrupted_grad_layer_out * (clean_layer_out - corrupted_layer_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\",\n"," )\n"," return layer_out_attr, labels\n","\n","\n","layer_out_attr, layer_out_labels = attr_patch_layer_out(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," layer_out_attr,\n"," y=layer_out_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=\"Layer Output Attribution Patching\",\n",")"]},{"cell_type":"code","execution_count":19,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_head_out(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_out = clean_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_head_out = corrupted_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_grad_head_out = corrupted_grad_cache.stack_head_results(\n"," -1, return_labels=False\n"," )\n"," head_out_attr = einops.reduce(\n"," corrupted_grad_head_out * (clean_head_out - corrupted_head_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\",\n"," )\n"," return head_out_attr, labels\n","\n","\n","head_out_attr, head_out_labels = attr_patch_head_out(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," head_out_attr,\n"," y=head_out_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=\"Head Output Attribution Patching\",\n",")\n","sum_head_out_attr = einops.reduce(\n"," head_out_attr,\n"," \"(layer head) pos -> layer head\",\n"," \"sum\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n",")\n","imshow(\n"," sum_head_out_attr,\n"," yaxis=\"Layer\",\n"," xaxis=\"Head Index\",\n"," title=\"Head Output Attribution Patching Sum Over Pos\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ### Head Activation Patching\n"," Intuitively, a head has three inputs, keys, queries and values. We can patch each of these individually to get a sense for where the important part of each head's input comes from!\n"," As a sanity check, we also do this for the mixed value. The result is a linear map of this (`z @ W_O == result`), so this is the same as patching the output of the head.\n"," We plot both the patch for each head over each position, and summed over position (it tends to be pretty sparse, so the latter is the same)"]},{"cell_type":"code","execution_count":20,"metadata":{},"outputs":[{"data":{"text/markdown":["#### Key Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Query Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Mixed Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","\n","\n","def stack_head_vector_from_cache(\n"," cache, activation_name: Literal[\"q\", \"k\", \"v\", \"z\"]\n",") -> TT[\"layer_and_head_index\", \"batch\", \"pos\", \"d_head\"]:\n"," \"\"\"Stacks the head vectors from the cache from a specific activation (key, query, value or mixed_value (z)) into a single tensor.\"\"\"\n"," stacked_head_vectors = torch.stack(\n"," [cache[activation_name, l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," stacked_head_vectors = einops.rearrange(\n"," stacked_head_vectors,\n"," \"layer batch pos head_index d_head -> (layer head_index) batch pos d_head\",\n"," )\n"," return stacked_head_vectors\n","\n","\n","def attr_patch_head_vector(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n"," activation_name: Literal[\"q\", \"k\", \"v\", \"z\"],\n",") -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_vector = stack_head_vector_from_cache(clean_cache, activation_name)\n"," corrupted_head_vector = stack_head_vector_from_cache(\n"," corrupted_cache, activation_name\n"," )\n"," corrupted_grad_head_vector = stack_head_vector_from_cache(\n"," corrupted_grad_cache, activation_name\n"," )\n"," head_vector_attr = einops.reduce(\n"," corrupted_grad_head_vector * (clean_head_vector - corrupted_head_vector),\n"," \"component batch pos d_head -> component pos\",\n"," \"sum\",\n"," )\n"," return head_vector_attr, labels\n","\n","\n","head_vector_attr_dict = {}\n","for activation_name, activation_name_full in [\n"," (\"k\", \"Key\"),\n"," (\"q\", \"Query\"),\n"," (\"v\", \"Value\"),\n"," (\"z\", \"Mixed Value\"),\n","]:\n"," display(Markdown(f\"#### {activation_name_full} Head Vector Attribution Patching\"))\n"," head_vector_attr_dict[activation_name], head_vector_labels = attr_patch_head_vector(\n"," clean_cache, corrupted_cache, corrupted_grad_cache, activation_name\n"," )\n"," imshow(\n"," head_vector_attr_dict[activation_name],\n"," y=head_vector_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=f\"{activation_name_full} Attribution Patching\",\n"," )\n"," sum_head_vector_attr = einops.reduce(\n"," head_vector_attr_dict[activation_name],\n"," \"(layer head) pos -> layer head\",\n"," \"sum\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," )\n"," imshow(\n"," sum_head_vector_attr,\n"," yaxis=\"Layer\",\n"," xaxis=\"Head Index\",\n"," title=f\"{activation_name_full} Attribution Patching Sum Over Pos\",\n"," )"]},{"cell_type":"code","execution_count":21,"metadata":{},"outputs":[{"data":{"text/markdown":["### Head Pattern Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","\n","\n","def stack_head_pattern_from_cache(\n"," cache,\n",") -> TT[\"layer_and_head_index\", \"batch\", \"dest_pos\", \"src_pos\"]:\n"," \"\"\"Stacks the head patterns from the cache into a single tensor.\"\"\"\n"," stacked_head_pattern = torch.stack(\n"," [cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," stacked_head_pattern = einops.rearrange(\n"," stacked_head_pattern,\n"," \"layer batch head_index dest_pos src_pos -> (layer head_index) batch dest_pos src_pos\",\n"," )\n"," return stacked_head_pattern\n","\n","\n","def attr_patch_head_pattern(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"dest_pos\", \"src_pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_pattern = stack_head_pattern_from_cache(clean_cache)\n"," corrupted_head_pattern = stack_head_pattern_from_cache(corrupted_cache)\n"," corrupted_grad_head_pattern = stack_head_pattern_from_cache(corrupted_grad_cache)\n"," head_pattern_attr = einops.reduce(\n"," corrupted_grad_head_pattern * (clean_head_pattern - corrupted_head_pattern),\n"," \"component batch dest_pos src_pos -> component dest_pos src_pos\",\n"," \"sum\",\n"," )\n"," return head_pattern_attr, labels\n","\n","\n","head_pattern_attr, labels = attr_patch_head_pattern(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","\n","plot_attention_attr(\n"," einops.rearrange(\n"," head_pattern_attr,\n"," \"(layer head) dest src -> layer head dest src\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," ),\n"," clean_tokens,\n"," index=0,\n"," title=\"Head Pattern Attribution Patching\",\n",")"]},{"cell_type":"code","execution_count":22,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_head_vector_grad_input_from_grad_cache(\n"," grad_cache: ActivationCache, activation_name: Literal[\"q\", \"k\", \"v\"], layer: int\n",") -> TT[\"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," vector_grad = grad_cache[activation_name, layer]\n"," ln_scales = grad_cache[\"scale\", layer, \"ln1\"]\n"," attn_layer_object = model.blocks[layer].attn\n"," if activation_name == \"q\":\n"," W = attn_layer_object.W_Q\n"," elif activation_name == \"k\":\n"," W = attn_layer_object.W_K\n"," elif activation_name == \"v\":\n"," W = attn_layer_object.W_V\n"," else:\n"," raise ValueError(\"Invalid activation name\")\n","\n"," return einsum(\n"," \"batch pos head_index d_head, batch pos, head_index d_model d_head -> batch pos head_index d_model\",\n"," vector_grad,\n"," ln_scales.squeeze(-1),\n"," W,\n"," )\n","\n","\n","def get_stacked_head_vector_grad_input(\n"," grad_cache, activation_name: Literal[\"q\", \"k\", \"v\"]\n",") -> TT[\"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack(\n"," [\n"," get_head_vector_grad_input_from_grad_cache(grad_cache, activation_name, l)\n"," for l in range(model.cfg.n_layers)\n"," ],\n"," dim=0,\n"," )\n","\n","\n","def get_full_vector_grad_input(\n"," grad_cache,\n",") -> TT[\"qkv\", \"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack(\n"," [\n"," get_stacked_head_vector_grad_input(grad_cache, activation_name)\n"," for activation_name in [\"q\", \"k\", \"v\"]\n"," ],\n"," dim=0,\n"," )\n","\n","\n","def attr_patch_head_path(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"qkv\", \"dest_component\", \"src_component\", \"pos\"]:\n"," \"\"\"\n"," Computes the attribution patch along the path between each pair of heads.\n","\n"," Sets this to zero for the path from any late head to any early head\n","\n"," \"\"\"\n"," start_labels = HEAD_NAMES\n"," end_labels = HEAD_NAMES_QKV\n"," full_vector_grad_input = get_full_vector_grad_input(corrupted_grad_cache)\n"," clean_head_result_stack = clean_cache.stack_head_results(-1)\n"," corrupted_head_result_stack = corrupted_cache.stack_head_results(-1)\n"," diff_head_result = einops.rearrange(\n"," clean_head_result_stack - corrupted_head_result_stack,\n"," \"(layer head_index) batch pos d_model -> layer batch pos head_index d_model\",\n"," layer=model.cfg.n_layers,\n"," head_index=model.cfg.n_heads,\n"," )\n"," path_attr = einsum(\n"," \"qkv layer_end batch pos head_end d_model, layer_start batch pos head_start d_model -> qkv layer_end head_end layer_start head_start pos\",\n"," full_vector_grad_input,\n"," diff_head_result,\n"," )\n"," correct_layer_order_mask = (\n"," torch.arange(model.cfg.n_layers)[None, :, None, None, None, None]\n"," > torch.arange(model.cfg.n_layers)[None, None, None, :, None, None]\n"," ).to(path_attr.device)\n"," zero = torch.zeros(1, device=path_attr.device)\n"," path_attr = torch.where(correct_layer_order_mask, path_attr, zero)\n","\n"," path_attr = einops.rearrange(\n"," path_attr,\n"," \"qkv layer_end head_end layer_start head_start pos -> (layer_end head_end qkv) (layer_start head_start) pos\",\n"," )\n"," return path_attr, end_labels, start_labels\n","\n","\n","head_path_attr, end_labels, start_labels = attr_patch_head_path(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," head_path_attr.sum(-1),\n"," y=end_labels,\n"," yaxis=\"Path End (Head Input)\",\n"," x=start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=\"Head Path Attribution Patching\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" This is hard to parse. Here's an experiment with filtering for the most important heads and showing their paths."]},{"cell_type":"code","execution_count":23,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["head_out_values, head_out_indices = head_out_attr.sum(-1).abs().sort(descending=True)\n","line(head_out_values)\n","top_head_indices = head_out_indices[:22].sort().values\n","top_end_indices = []\n","top_end_labels = []\n","top_start_indices = []\n","top_start_labels = []\n","for i in top_head_indices:\n"," i = i.item()\n"," top_start_indices.append(i)\n"," top_start_labels.append(start_labels[i])\n"," for j in range(3):\n"," top_end_indices.append(3 * i + j)\n"," top_end_labels.append(end_labels[3 * i + j])\n","\n","imshow(\n"," head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n"," y=top_end_labels,\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=\"Head Path Attribution Patching (Filtered for Top Heads)\",\n",")"]},{"cell_type":"code","execution_count":24,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["for j, composition_type in enumerate([\"Query\", \"Key\", \"Value\"]):\n"," imshow(\n"," head_path_attr[top_end_indices, :][:, top_start_indices][j::3].sum(-1),\n"," y=top_end_labels[j::3],\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=f\"Head Path to {composition_type} Attribution Patching (Filtered for Top Heads)\",\n"," )"]},{"cell_type":"code","execution_count":25,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["top_head_path_attr = einops.rearrange(\n"," head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n"," \"(head_end qkv) head_start -> qkv head_end head_start\",\n"," qkv=3,\n",")\n","imshow(\n"," top_head_path_attr,\n"," y=[i[:-1] for i in top_end_labels[::3]],\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=f\"Head Path Attribution Patching (Filtered for Top Heads)\",\n"," facet_col=0,\n"," facet_labels=[\"Query\", \"Key\", \"Value\"],\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" Let's now dive into 3 interesting heads: L5H5 (induction head), L8H6 (S-Inhibition Head), L9H9 (Name Mover) and look at their input and output paths (note - Q input means )"]},{"cell_type":"code","execution_count":26,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["interesting_heads = [\n"," 5 * model.cfg.n_heads + 5,\n"," 8 * model.cfg.n_heads + 6,\n"," 9 * model.cfg.n_heads + 9,\n","]\n","interesting_head_labels = [HEAD_NAMES[i] for i in interesting_heads]\n","for head_index, label in zip(interesting_heads, interesting_head_labels):\n"," in_paths = head_path_attr[3 * head_index : 3 * head_index + 3].sum(-1)\n"," out_paths = head_path_attr[:, head_index].sum(-1)\n"," out_paths = einops.rearrange(out_paths, \"(layer_head qkv) -> qkv layer_head\", qkv=3)\n"," all_paths = torch.cat([in_paths, out_paths], dim=0)\n"," all_paths = einops.rearrange(\n"," all_paths,\n"," \"path_type (layer head) -> path_type layer head\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," )\n"," imshow(\n"," all_paths,\n"," facet_col=0,\n"," facet_labels=[\n"," \"Query (In)\",\n"," \"Key (In)\",\n"," \"Value (In)\",\n"," \"Query (Out)\",\n"," \"Key (Out)\",\n"," \"Value (Out)\",\n"," ],\n"," title=f\"Input and Output Paths for head {label}\",\n"," yaxis=\"Layer\",\n"," xaxis=\"Head\",\n"," )"]},{"cell_type":"markdown","metadata":{},"source":[" ## Validating Attribution vs Activation Patching\n"," Let's now compare attribution and activation patching. Generally it's a decent approximation! The main place it fails is MLP0 and the residual stream\n"," My fuzzy intuition is that attribution patching works badly for \"big\" things which are poorly modelled as linear approximations, and works well for \"small\" things which are more like incremental changes. Anything involving replacing the embedding is a \"big\" thing, which includes residual streams, and in GPT-2 small MLP0 seems to be used as an \"extended embedding\" (where later layers use MLP0's output instead of the token embedding), so I also count it as big.\n"," See more discussion in the accompanying blog post!\n"]},{"cell_type":"markdown","metadata":{},"source":[" First do some refactoring to make attribution patching more generic. We make an attribution cache, which is an ActivationCache where each element is (clean_act - corrupted_act) * corrupted_grad, so that it's the per-element attribution for each activation. Thanks to linearity, we just compute things by adding stuff up along the relevant dimensions!"]},{"cell_type":"code","execution_count":27,"metadata":{},"outputs":[],"source":["attribution_cache_dict = {}\n","for key in corrupted_grad_cache.cache_dict.keys():\n"," attribution_cache_dict[key] = corrupted_grad_cache.cache_dict[key] * (\n"," clean_cache.cache_dict[key] - corrupted_cache.cache_dict[key]\n"," )\n","attr_cache = ActivationCache(attribution_cache_dict, model)"]},{"cell_type":"markdown","metadata":{},"source":[" By block: For each head we patch the starting residual stream, attention output + MLP output"]},{"cell_type":"code","execution_count":28,"metadata":{},"outputs":[],"source":["str_tokens = model.to_str_tokens(clean_tokens[0])\n","context_length = len(str_tokens)"]},{"cell_type":"code","execution_count":29,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"95a5290e11b64b6a95ef5dd37d027c7a","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_block_act_patch_result = patching.get_act_patch_block_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","imshow(\n"," every_block_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Activation Patching Per Block\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",")"]},{"cell_type":"code","execution_count":30,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_block_every(attr_cache):\n"," resid_pre_attr = einops.reduce(\n"," attr_cache.stack_activation(\"resid_pre\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," attn_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"attn_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," mlp_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"mlp_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n","\n"," every_block_attr_patch_result = torch.stack(\n"," [resid_pre_attr, attn_out_attr, mlp_out_attr], dim=0\n"," )\n"," return every_block_attr_patch_result\n","\n","\n","every_block_attr_patch_result = get_attr_patch_block_every(attr_cache)\n","imshow(\n"," every_block_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Attribution Patching Per Block\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",")"]},{"cell_type":"code","execution_count":31,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_block_attr_patch_result.reshape(3, -1),\n"," x=every_block_act_patch_result.reshape(3, -1),\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Attribution vs Activation Patching Per Block\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," hover=[\n"," f\"Layer {l}, Position {p}, |{str_tokens[p]}|\"\n"," for l in range(model.cfg.n_layers)\n"," for p in range(context_length)\n"," ],\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers), \"layer -> (layer pos)\", pos=context_length\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow."]},{"cell_type":"code","execution_count":32,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"18b2e6b0985b40cd8c0cd1a16ba62975","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","imshow(\n"," every_head_all_pos_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Activation Patching Per Head (All Pos)\",\n"," xaxis=\"Head\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n",")"]},{"cell_type":"code","execution_count":33,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_all_pos_every(attr_cache):\n"," head_out_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_q_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_k_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_v_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack(\n"," [\n"," head_out_all_pos_attr,\n"," head_q_all_pos_attr,\n"," head_k_all_pos_attr,\n"," head_v_all_pos_attr,\n"," head_pattern_all_pos_attr,\n"," ]\n"," )\n","\n","\n","every_head_all_pos_attr_patch_result = get_attr_patch_attn_head_all_pos_every(\n"," attr_cache\n",")\n","imshow(\n"," every_head_all_pos_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution Patching Per Head (All Pos)\",\n"," xaxis=\"Head\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n",")"]},{"cell_type":"code","execution_count":34,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_head_all_pos_attr_patch_result.reshape(5, -1),\n"," x=every_head_all_pos_act_patch_result.reshape(5, -1),\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution vs Activation Patching Per Head (All Pos)\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," include_diag=True,\n"," hover=head_out_labels,\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers),\n"," \"layer -> (layer head)\",\n"," head=model.cfg.n_heads,\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" We see pretty good results in general, but significant errors for heads L5H5 on query and moderate errors for head L10H7 on query and key, and moderate errors for head L11H10 on key. But each of these is fine for pattern and output. My guess is that the problem is that these have pretty saturated attention on a single token, and the linear approximation is thus not great on the attention calculation here, but I'm not sure. When we plot the attention patterns, we do see this!\n"," Note that the axis labels are for the *first* prompt's tokens, but each facet is a different prompt, so this is somewhat inaccurate. In particular, every odd facet has indirect object and subject in the opposite order (IO first). But otherwise everything lines up between the prompts"]},{"cell_type":"code","execution_count":35,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["graph_tok_labels = [\n"," f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))\n","]\n","imshow(\n"," clean_cache[\"pattern\", 5][:, 5],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L5H5\",\n"," facet_name=\"Prompt\",\n",")\n","imshow(\n"," clean_cache[\"pattern\", 10][:, 7],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L10H7\",\n"," facet_name=\"Prompt\",\n",")\n","imshow(\n"," clean_cache[\"pattern\", 11][:, 10],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L11H10\",\n"," facet_name=\"Prompt\",\n",")\n","\n","\n","# [markdown]"]},{"cell_type":"code","execution_count":36,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"06f39489001845849fbc7446a07066f4","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2160 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_by_pos_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","every_head_by_pos_act_patch_result = einops.rearrange(\n"," every_head_by_pos_act_patch_result,\n"," \"act_type layer pos head -> act_type (layer head) pos\",\n",")\n","imshow(\n"," every_head_by_pos_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Activation Patching Per Head (By Pos)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer & Head\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," y=head_out_labels,\n",")"]},{"cell_type":"code","execution_count":37,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_by_pos_every(attr_cache):\n"," head_out_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_q_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_k_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_v_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer dest_pos head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack(\n"," [\n"," head_out_by_pos_attr,\n"," head_q_by_pos_attr,\n"," head_k_by_pos_attr,\n"," head_v_by_pos_attr,\n"," head_pattern_by_pos_attr,\n"," ]\n"," )\n","\n","\n","every_head_by_pos_attr_patch_result = get_attr_patch_attn_head_by_pos_every(attr_cache)\n","every_head_by_pos_attr_patch_result = einops.rearrange(\n"," every_head_by_pos_attr_patch_result,\n"," \"act_type layer pos head -> act_type (layer head) pos\",\n",")\n","imshow(\n"," every_head_by_pos_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution Patching Per Head (By Pos)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer & Head\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," y=head_out_labels,\n",")"]},{"cell_type":"code","execution_count":38,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_head_by_pos_attr_patch_result.reshape(5, -1),\n"," x=every_head_by_pos_act_patch_result.reshape(5, -1),\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution vs Activation Patching Per Head (by Pos)\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," include_diag=True,\n"," hover=[f\"{label} {tok}\" for label in head_out_labels for tok in graph_tok_labels],\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers),\n"," \"layer -> (layer head pos)\",\n"," head=model.cfg.n_heads,\n"," pos=15,\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Factual Knowledge Patching Example\n"," Incomplete, but maybe of interest!\n"," Note that I have better results with the corrupted prompt as having random words rather than Colosseum."]},{"cell_type":"code","execution_count":39,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-xl into HookedTransformer\n","Tokenized prompt: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Paris']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.73 Prob: 95.80% Token: | Paris|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.73\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m95.80\u001b[0m\u001b[1m% Token: | Paris|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.73 Prob: 95.80% Token: | Paris|\n","Top 1th token. Logit: 16.49 Prob: 1.39% Token: | E|\n","Top 2th token. Logit: 14.69 Prob: 0.23% Token: | the|\n","Top 3th token. Logit: 14.58 Prob: 0.21% Token: | É|\n","Top 4th token. Logit: 14.44 Prob: 0.18% Token: | France|\n","Top 5th token. Logit: 14.36 Prob: 0.16% Token: | Mont|\n","Top 6th token. Logit: 13.77 Prob: 0.09% Token: | Le|\n","Top 7th token. Logit: 13.66 Prob: 0.08% Token: | Ang|\n","Top 8th token. Logit: 13.43 Prob: 0.06% Token: | V|\n","Top 9th token. Logit: 13.42 Prob: 0.06% Token: | Stras|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Paris', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Paris'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Tokenized prompt: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Rome']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.02 Prob: 83.70% Token: | Rome|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.02\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m83.70\u001b[0m\u001b[1m% Token: | Rome|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.02 Prob: 83.70% Token: | Rome|\n","Top 1th token. Logit: 17.03 Prob: 4.23% Token: | Naples|\n","Top 2th token. Logit: 16.85 Prob: 3.51% Token: | Pompe|\n","Top 3th token. Logit: 16.14 Prob: 1.73% Token: | Ver|\n","Top 4th token. Logit: 15.87 Prob: 1.32% Token: | Florence|\n","Top 5th token. Logit: 14.77 Prob: 0.44% Token: | Roma|\n","Top 6th token. Logit: 14.68 Prob: 0.40% Token: | Milan|\n","Top 7th token. Logit: 14.66 Prob: 0.39% Token: | ancient|\n","Top 8th token. Logit: 14.37 Prob: 0.29% Token: | Pal|\n","Top 9th token. Logit: 14.30 Prob: 0.27% Token: | Constantinople|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Rome', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Rome'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"}],"source":["gpt2_xl = HookedTransformer.from_pretrained(\"gpt2-xl\")\n","clean_prompt = \"The Eiffel Tower is located in the city of\"\n","clean_answer = \" Paris\"\n","# corrupted_prompt = \"The red brown fox jumps is located in the city of\"\n","corrupted_prompt = \"The Colosseum is located in the city of\"\n","corrupted_answer = \" Rome\"\n","utils.test_prompt(clean_prompt, clean_answer, gpt2_xl)\n","utils.test_prompt(corrupted_prompt, corrupted_answer, gpt2_xl)"]},{"cell_type":"code","execution_count":40,"metadata":{},"outputs":[],"source":["clean_answer_index = gpt2_xl.to_single_token(clean_answer)\n","corrupted_answer_index = gpt2_xl.to_single_token(corrupted_answer)\n","\n","\n","def factual_logit_diff(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return logits[0, -1, clean_answer_index] - logits[0, -1, corrupted_answer_index]"]},{"cell_type":"code","execution_count":41,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 10.634519577026367\n","Corrupted logit diff: -8.988396644592285\n","Clean Metric: tensor(1., device='cuda:0', grad_fn=)\n","Corrupted Metric: tensor(0., device='cuda:0', grad_fn=)\n"]}],"source":["clean_logits, clean_cache = gpt2_xl.run_with_cache(clean_prompt)\n","CLEAN_LOGIT_DIFF_FACTUAL = factual_logit_diff(clean_logits).item()\n","corrupted_logits, _ = gpt2_xl.run_with_cache(corrupted_prompt)\n","CORRUPTED_LOGIT_DIFF_FACTUAL = factual_logit_diff(corrupted_logits).item()\n","\n","\n","def factual_metric(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return (factual_logit_diff(logits) - CORRUPTED_LOGIT_DIFF_FACTUAL) / (\n"," CLEAN_LOGIT_DIFF_FACTUAL - CORRUPTED_LOGIT_DIFF_FACTUAL\n"," )\n","\n","\n","print(\"Clean logit diff:\", CLEAN_LOGIT_DIFF_FACTUAL)\n","print(\"Corrupted logit diff:\", CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean Metric:\", factual_metric(clean_logits))\n","print(\"Corrupted Metric:\", factual_metric(corrupted_logits))"]},{"cell_type":"code","execution_count":42,"metadata":{},"outputs":[],"source":["# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(gpt2_xl, corrupted_prompt, factual_metric)"]},{"cell_type":"code","execution_count":43,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Corrupted: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n"]}],"source":["clean_tokens = gpt2_xl.to_tokens(clean_prompt)\n","clean_str_tokens = gpt2_xl.to_str_tokens(clean_prompt)\n","corrupted_tokens = gpt2_xl.to_tokens(corrupted_prompt)\n","corrupted_str_tokens = gpt2_xl.to_str_tokens(corrupted_prompt)\n","print(\"Clean:\", clean_str_tokens)\n","print(\"Corrupted:\", corrupted_str_tokens)"]},{"cell_type":"code","execution_count":44,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"b767eef7a3cd49b9b3cb6e5301463f08","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/48 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def act_patch_residual(clean_cache, corrupted_tokens, model: HookedTransformer, metric):\n"," if len(corrupted_tokens.shape) == 2:\n"," corrupted_tokens = corrupted_tokens[0]\n"," residual_patches = torch.zeros(\n"," (model.cfg.n_layers, len(corrupted_tokens)), device=model.cfg.device\n"," )\n","\n"," def residual_hook(resid_pre, hook, layer, pos):\n"," resid_pre[:, pos, :] = clean_cache[\"resid_pre\", layer][:, pos, :]\n"," return resid_pre\n","\n"," for layer in tqdm.tqdm(range(model.cfg.n_layers)):\n"," for pos in range(len(corrupted_tokens)):\n"," patched_logits = model.run_with_hooks(\n"," corrupted_tokens,\n"," fwd_hooks=[\n"," (\n"," f\"blocks.{layer}.hook_resid_pre\",\n"," partial(residual_hook, layer=layer, pos=pos),\n"," )\n"," ],\n"," )\n"," residual_patches[layer, pos] = metric(patched_logits).item()\n"," return residual_patches\n","\n","\n","residual_act_patch = act_patch_residual(\n"," clean_cache, corrupted_tokens, gpt2_xl, factual_metric\n",")\n","\n","imshow(\n"," residual_act_patch,\n"," title=\"Factual Recall Patching (Residual)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," x=clean_str_tokens,\n",")"]}],"metadata":{"kernelspec":{"display_name":"base","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.11.8"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"}}},"nbformat":4,"nbformat_minor":2} +{"cells":[{"cell_type":"markdown","metadata":{},"source":["\n"," \"Open\n",""]},{"cell_type":"markdown","metadata":{},"source":[" # Attribution Patching Demo\n"," **Read [the accompanying blog post here](https://neelnanda.io/attribution-patching) for more context**\n"," This is an interim research report, giving a whirlwind tour of some unpublished work I did at Anthropic (credit to the then team - Chris Olah, Catherine Olsson, Nelson Elhage and Tristan Hume for help, support, and mentorship!)\n","\n"," The goal of this work is run activation patching at an industrial scale, by using gradient based attribution to approximate the technique - allow an arbitrary number of patches to be made on two forwards and a single backward pass\n","\n"," I have had less time than hoped to flesh out this investigation, but am writing up a rough investigation and comparison to standard activation patching on a few tasks to give a sense of the potential of this approach, and where it works vs falls down."]},{"cell_type":"markdown","metadata":{},"source":[" To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.\n","\n"," **Tips for reading this Colab:**\n"," * You can run all this code for yourself!\n"," * The graphs are interactive!\n"," * Use the table of contents pane in the sidebar to navigate\n"," * Collapse irrelevant sections with the dropdown arrows\n"," * Search the page using the search in the sidebar, not CTRL+F"]},{"cell_type":"markdown","metadata":{},"source":[" ## Setup (Ignore)"]},{"cell_type":"code","execution_count":1,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Running as a Jupyter notebook - intended for development only!\n"]},{"name":"stderr","output_type":"stream","text":["/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_25358/2480103146.py:24: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n"," ipython.magic(\"load_ext autoreload\")\n","/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_25358/2480103146.py:25: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n"," ipython.magic(\"autoreload 2\")\n"]}],"source":["# Janky code to do different setup when run in a Colab notebook vs VSCode\n","import os\n","\n","DEBUG_MODE = False\n","IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n","try:\n"," import google.colab\n","\n"," IN_COLAB = True\n"," print(\"Running as a Colab notebook\")\n","except:\n"," IN_COLAB = False\n"," print(\"Running as a Jupyter notebook - intended for development only!\")\n"," from IPython import get_ipython\n","\n"," ipython = get_ipython()\n"," # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n"," ipython.magic(\"load_ext autoreload\")\n"," ipython.magic(\"autoreload 2\")\n","\n","if IN_COLAB or IN_GITHUB:\n"," %pip install transformer_lens\n"," %pip install torchtyping\n"," # Install my janky personal plotting utils\n"," %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n"," # Install another version of node that makes PySvelte work way faster\n"," %pip install circuitsvis\n"," # Needed for PySvelte to work, v3 came out and broke things...\n"," %pip install typeguard==2.13.3"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n","import plotly.io as pio\n","\n","if IN_COLAB or not DEBUG_MODE:\n"," # Thanks to annoying rendering issues, Plotly graphics will either show up in colab OR Vscode depending on the renderer - this is bad for developing demos! Thus creating a debug mode.\n"," pio.renderers.default = \"colab\"\n","else:\n"," pio.renderers.default = \"notebook_connected\""]},{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[{"ename":"ModuleNotFoundError","evalue":"No module named 'torchtyping'","output_type":"error","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)","Cell \u001b[0;32mIn[3], line 15\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mplotly\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mexpress\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpx\u001b[39;00m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DataLoader\n\u001b[0;32m---> 15\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorchtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m TensorType \u001b[38;5;28;01mas\u001b[39;00m TT\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m List, Union, Optional, Callable\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mfunctools\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m partial\n","\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torchtyping'"]}],"source":["# Import stuff\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","import torch.optim as optim\n","import numpy as np\n","import einops\n","from fancy_einsum import einsum\n","import tqdm.notebook as tqdm\n","import random\n","from pathlib import Path\n","import plotly.express as px\n","from torch.utils.data import DataLoader\n","\n","from torchtyping import TensorType as TT\n","from typing import List, Union, Optional, Callable\n","from functools import partial\n","import copy\n","import itertools\n","import json\n","\n","from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n","import dataclasses\n","import datasets\n","from IPython.display import HTML, Markdown"]},{"cell_type":"code","execution_count":5,"metadata":{},"outputs":[],"source":["import transformer_lens\n","import transformer_lens.utils as utils\n","from transformer_lens.hook_points import (\n"," HookedRootModule,\n"," HookPoint,\n",") # Hooking utilities\n","from transformer_lens import (\n"," HookedTransformer,\n"," HookedTransformerConfig,\n"," FactoredMatrix,\n"," ActivationCache,\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" Plotting helper functions from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:"]},{"cell_type":"code","execution_count":6,"metadata":{},"outputs":[],"source":["from neel_plotly import line, imshow, scatter"]},{"cell_type":"code","execution_count":7,"metadata":{},"outputs":[],"source":["import transformer_lens.patching as patching"]},{"cell_type":"markdown","metadata":{},"source":[" ## IOI Patching Setup\n"," This just copies the relevant set up from Exploratory Analysis Demo, and isn't very important."]},{"cell_type":"code","execution_count":8,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-small into HookedTransformer\n"]}],"source":["model = HookedTransformer.from_pretrained(\"gpt2-small\")\n","model.set_use_attn_result(True)"]},{"cell_type":"code","execution_count":9,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean string 0 <|endoftext|>When John and Mary went to the shops, John gave the bag to\n","Corrupted string 0 <|endoftext|>When John and Mary went to the shops, Mary gave the bag to\n","Answer token indices tensor([[ 5335, 1757],\n"," [ 1757, 5335],\n"," [ 4186, 3700],\n"," [ 3700, 4186],\n"," [ 6035, 15686],\n"," [15686, 6035],\n"," [ 5780, 14235],\n"," [14235, 5780]], device='cuda:0')\n"]}],"source":["prompts = [\n"," \"When John and Mary went to the shops, John gave the bag to\",\n"," \"When John and Mary went to the shops, Mary gave the bag to\",\n"," \"When Tom and James went to the park, James gave the ball to\",\n"," \"When Tom and James went to the park, Tom gave the ball to\",\n"," \"When Dan and Sid went to the shops, Sid gave an apple to\",\n"," \"When Dan and Sid went to the shops, Dan gave an apple to\",\n"," \"After Martin and Amy went to the park, Amy gave a drink to\",\n"," \"After Martin and Amy went to the park, Martin gave a drink to\",\n","]\n","answers = [\n"," (\" Mary\", \" John\"),\n"," (\" John\", \" Mary\"),\n"," (\" Tom\", \" James\"),\n"," (\" James\", \" Tom\"),\n"," (\" Dan\", \" Sid\"),\n"," (\" Sid\", \" Dan\"),\n"," (\" Martin\", \" Amy\"),\n"," (\" Amy\", \" Martin\"),\n","]\n","\n","clean_tokens = model.to_tokens(prompts)\n","# Swap each adjacent pair, with a hacky list comprehension\n","corrupted_tokens = clean_tokens[\n"," [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]\n","]\n","print(\"Clean string 0\", model.to_string(clean_tokens[0]))\n","print(\"Corrupted string 0\", model.to_string(corrupted_tokens[0]))\n","\n","answer_token_indices = torch.tensor(\n"," [\n"," [model.to_single_token(answers[i][j]) for j in range(2)]\n"," for i in range(len(answers))\n"," ],\n"," device=model.cfg.device,\n",")\n","print(\"Answer token indices\", answer_token_indices)"]},{"cell_type":"code","execution_count":10,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 3.5519\n","Corrupted logit diff: -3.5519\n"]}],"source":["def get_logit_diff(logits, answer_token_indices=answer_token_indices):\n"," if len(logits.shape) == 3:\n"," # Get final logits only\n"," logits = logits[:, -1, :]\n"," correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))\n"," incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))\n"," return (correct_logits - incorrect_logits).mean()\n","\n","\n","clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n","corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n","\n","clean_logit_diff = get_logit_diff(clean_logits, answer_token_indices).item()\n","print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n","\n","corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_token_indices).item()\n","print(f\"Corrupted logit diff: {corrupted_logit_diff:.4f}\")"]},{"cell_type":"code","execution_count":11,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Baseline is 1: 1.0000\n","Corrupted Baseline is 0: 0.0000\n"]}],"source":["CLEAN_BASELINE = clean_logit_diff\n","CORRUPTED_BASELINE = corrupted_logit_diff\n","\n","\n","def ioi_metric(logits, answer_token_indices=answer_token_indices):\n"," return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (\n"," CLEAN_BASELINE - CORRUPTED_BASELINE\n"," )\n","\n","\n","print(f\"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}\")\n","print(f\"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}\")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Patching\n"," In the following cells, we define attribution patching and use it in various ways on the model."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["Metric = Callable[[TT[\"batch_and_pos_dims\", \"d_model\"]], float]"]},{"cell_type":"code","execution_count":13,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean Value: 1.0\n","Clean Activations Cached: 220\n","Clean Gradients Cached: 220\n","Corrupted Value: 0.0\n","Corrupted Activations Cached: 220\n","Corrupted Gradients Cached: 220\n"]}],"source":["filter_not_qkv_input = lambda name: \"_input\" not in name\n","\n","\n","def get_cache_fwd_and_bwd(model, tokens, metric):\n"," model.reset_hooks()\n"," cache = {}\n","\n"," def forward_cache_hook(act, hook):\n"," cache[hook.name] = act.detach()\n","\n"," model.add_hook(filter_not_qkv_input, forward_cache_hook, \"fwd\")\n","\n"," grad_cache = {}\n","\n"," def backward_cache_hook(act, hook):\n"," grad_cache[hook.name] = act.detach()\n","\n"," model.add_hook(filter_not_qkv_input, backward_cache_hook, \"bwd\")\n","\n"," value = metric(model(tokens))\n"," value.backward()\n"," model.reset_hooks()\n"," return (\n"," value.item(),\n"," ActivationCache(cache, model),\n"," ActivationCache(grad_cache, model),\n"," )\n","\n","\n","clean_value, clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(\n"," model, clean_tokens, ioi_metric\n",")\n","print(\"Clean Value:\", clean_value)\n","print(\"Clean Activations Cached:\", len(clean_cache))\n","print(\"Clean Gradients Cached:\", len(clean_grad_cache))\n","corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(\n"," model, corrupted_tokens, ioi_metric\n",")\n","print(\"Corrupted Value:\", corrupted_value)\n","print(\"Corrupted Activations Cached:\", len(corrupted_cache))\n","print(\"Corrupted Gradients Cached:\", len(corrupted_grad_cache))"]},{"cell_type":"markdown","metadata":{},"source":[" ### Attention Attribution\n"," The easiest thing to start with is to not even engage with the corrupted tokens/patching, but to look at the attribution of the attention patterns - that is, the linear approximation to what happens if you set each element of the attention pattern to zero. This, as it turns out, is a good proxy to what is going on with each head!\n"," Note that this is *not* the same as what we will later do with patching. In particular, this does not set up a careful counterfactual! It's a good tool for what's generally going on in this problem, but does not control for eg stuff that systematically boosts John > Mary in general, stuff that says \"I should activate the IOI circuit\", etc. Though using logit diff as our metric *does*\n"," Each element of the batch is independent and the metric is an average logit diff, so we can analyse each batch element independently here. We'll look at the first one, and then at the average across the whole batch (note - 4 prompts have indirect object before subject, 4 prompts have it the other way round, making the average pattern harder to interpret - I plot it over the first sequence of tokens as a mildly misleading reference).\n"," We can compare it to the interpretability in the wild diagram, and basically instantly recover most of the circuit!"]},{"cell_type":"code","execution_count":14,"metadata":{},"outputs":[],"source":["def create_attention_attr(\n"," clean_cache, clean_grad_cache\n",") -> TT[\"batch\", \"layer\", \"head_index\", \"dest\", \"src\"]:\n"," attention_stack = torch.stack(\n"," [clean_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," attention_grad_stack = torch.stack(\n"," [clean_grad_cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," attention_attr = attention_grad_stack * attention_stack\n"," attention_attr = einops.rearrange(\n"," attention_attr,\n"," \"layer batch head_index dest src -> batch layer head_index dest src\",\n"," )\n"," return attention_attr\n","\n","\n","attention_attr = create_attention_attr(clean_cache, clean_grad_cache)"]},{"cell_type":"code","execution_count":15,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4']\n","['L0H0+', 'L0H0-', 'L0H1+', 'L0H1-', 'L0H2+']\n","['L0H0Q', 'L0H0K', 'L0H0V', 'L0H1Q', 'L0H1K']\n"]}],"source":["HEAD_NAMES = [\n"," f\"L{l}H{h}\" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)\n","]\n","HEAD_NAMES_SIGNED = [f\"{name}{sign}\" for name in HEAD_NAMES for sign in [\"+\", \"-\"]]\n","HEAD_NAMES_QKV = [\n"," f\"{name}{act_name}\" for name in HEAD_NAMES for act_name in [\"Q\", \"K\", \"V\"]\n","]\n","print(HEAD_NAMES[:5])\n","print(HEAD_NAMES_SIGNED[:5])\n","print(HEAD_NAMES_QKV[:5])"]},{"cell_type":"markdown","metadata":{},"source":[" An extremely janky way to plot the attention attribution patterns. We scale them to be in [-1, 1], split each head into a positive and negative part (so all of it is in [0, 1]), and then plot the top 20 head-halves (a head can appear twice!) by the max value of the attribution pattern."]},{"cell_type":"code","execution_count":16,"metadata":{},"outputs":[{"data":{"text/markdown":["### Attention Attribution for first sequence"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["### Summed Attention Attribution for all sequences"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\n"]}],"source":["def plot_attention_attr(attention_attr, tokens, top_k=20, index=0, title=\"\"):\n"," if len(tokens.shape) == 2:\n"," tokens = tokens[index]\n"," if len(attention_attr.shape) == 5:\n"," attention_attr = attention_attr[index]\n"," attention_attr_pos = attention_attr.clamp(min=-1e-5)\n"," attention_attr_neg = -attention_attr.clamp(max=1e-5)\n"," attention_attr_signed = torch.stack([attention_attr_pos, attention_attr_neg], dim=0)\n"," attention_attr_signed = einops.rearrange(\n"," attention_attr_signed,\n"," \"sign layer head_index dest src -> (layer head_index sign) dest src\",\n"," )\n"," attention_attr_signed = attention_attr_signed / attention_attr_signed.max()\n"," attention_attr_indices = (\n"," attention_attr_signed.max(-1).values.max(-1).values.argsort(descending=True)\n"," )\n"," # print(attention_attr_indices.shape)\n"," # print(attention_attr_indices)\n"," attention_attr_signed = attention_attr_signed[attention_attr_indices, :, :]\n"," head_labels = [HEAD_NAMES_SIGNED[i.item()] for i in attention_attr_indices]\n","\n"," if title:\n"," display(Markdown(\"### \" + title))\n"," display(\n"," pysvelte.AttentionMulti(\n"," tokens=model.to_str_tokens(tokens),\n"," attention=attention_attr_signed.permute(1, 2, 0)[:, :, :top_k],\n"," head_labels=head_labels[:top_k],\n"," )\n"," )\n","\n","\n","plot_attention_attr(\n"," attention_attr,\n"," clean_tokens,\n"," index=0,\n"," title=\"Attention Attribution for first sequence\",\n",")\n","\n","plot_attention_attr(\n"," attention_attr.sum(0),\n"," clean_tokens[0],\n"," title=\"Summed Attention Attribution for all sequences\",\n",")\n","print(\n"," \"Note: Plotted over first sequence for reference, but pairs have IO and S1 in different positions.\"\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Attribution Patching\n"," In the following sections, I will implement various kinds of attribution patching, and then compare them to the activation patching patterns (activation patching code copied from [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo))\n"," ### Residual Stream Patching\n","
Note: We add up across both d_model and batch (Explanation).\n"," We add up along d_model because we're taking the dot product - the derivative *is* the linear map that locally linearly approximates the metric, and so we take the dot product of our change vector with the derivative vector. Equivalent, we look at the effect of changing each coordinate independently, and then combine them by adding it up - it's linear, so this totally works.\n"," We add up across batch because we're taking the average of the metric, so each individual batch element provides `1/batch_size` of the overall effect. Because each batch element is independent of the others and no information moves between activations for different inputs, the batched version is equivalent to doing attribution patching separately for each input, and then averaging - in this second version the metric per input is *not* divided by batch_size because we don't average.
"]},{"cell_type":"code","execution_count":17,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_residual(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"pos\"]:\n"," clean_residual, residual_labels = clean_cache.accumulated_resid(\n"," -1, incl_mid=True, return_labels=True\n"," )\n"," corrupted_residual = corrupted_cache.accumulated_resid(\n"," -1, incl_mid=True, return_labels=False\n"," )\n"," corrupted_grad_residual = corrupted_grad_cache.accumulated_resid(\n"," -1, incl_mid=True, return_labels=False\n"," )\n"," residual_attr = einops.reduce(\n"," corrupted_grad_residual * (clean_residual - corrupted_residual),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\",\n"," )\n"," return residual_attr, residual_labels\n","\n","\n","residual_attr, residual_labels = attr_patch_residual(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," residual_attr,\n"," y=residual_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=\"Residual Attribution Patching\",\n",")\n","\n","# ### Layer Output Patching"]},{"cell_type":"code","execution_count":18,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_layer_out(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"pos\"]:\n"," clean_layer_out, labels = clean_cache.decompose_resid(-1, return_labels=True)\n"," corrupted_layer_out = corrupted_cache.decompose_resid(-1, return_labels=False)\n"," corrupted_grad_layer_out = corrupted_grad_cache.decompose_resid(\n"," -1, return_labels=False\n"," )\n"," layer_out_attr = einops.reduce(\n"," corrupted_grad_layer_out * (clean_layer_out - corrupted_layer_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\",\n"," )\n"," return layer_out_attr, labels\n","\n","\n","layer_out_attr, layer_out_labels = attr_patch_layer_out(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," layer_out_attr,\n"," y=layer_out_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=\"Layer Output Attribution Patching\",\n",")"]},{"cell_type":"code","execution_count":19,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def attr_patch_head_out(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_out = clean_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_head_out = corrupted_cache.stack_head_results(-1, return_labels=False)\n"," corrupted_grad_head_out = corrupted_grad_cache.stack_head_results(\n"," -1, return_labels=False\n"," )\n"," head_out_attr = einops.reduce(\n"," corrupted_grad_head_out * (clean_head_out - corrupted_head_out),\n"," \"component batch pos d_model -> component pos\",\n"," \"sum\",\n"," )\n"," return head_out_attr, labels\n","\n","\n","head_out_attr, head_out_labels = attr_patch_head_out(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," head_out_attr,\n"," y=head_out_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=\"Head Output Attribution Patching\",\n",")\n","sum_head_out_attr = einops.reduce(\n"," head_out_attr,\n"," \"(layer head) pos -> layer head\",\n"," \"sum\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n",")\n","imshow(\n"," sum_head_out_attr,\n"," yaxis=\"Layer\",\n"," xaxis=\"Head Index\",\n"," title=\"Head Output Attribution Patching Sum Over Pos\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ### Head Activation Patching\n"," Intuitively, a head has three inputs, keys, queries and values. We can patch each of these individually to get a sense for where the important part of each head's input comes from!\n"," As a sanity check, we also do this for the mixed value. The result is a linear map of this (`z @ W_O == result`), so this is the same as patching the output of the head.\n"," We plot both the patch for each head over each position, and summed over position (it tends to be pretty sparse, so the latter is the same)"]},{"cell_type":"code","execution_count":20,"metadata":{},"outputs":[{"data":{"text/markdown":["#### Key Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Query Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/markdown":["#### Mixed Value Head Vector Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","\n","\n","def stack_head_vector_from_cache(\n"," cache, activation_name: Literal[\"q\", \"k\", \"v\", \"z\"]\n",") -> TT[\"layer_and_head_index\", \"batch\", \"pos\", \"d_head\"]:\n"," \"\"\"Stacks the head vectors from the cache from a specific activation (key, query, value or mixed_value (z)) into a single tensor.\"\"\"\n"," stacked_head_vectors = torch.stack(\n"," [cache[activation_name, l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," stacked_head_vectors = einops.rearrange(\n"," stacked_head_vectors,\n"," \"layer batch pos head_index d_head -> (layer head_index) batch pos d_head\",\n"," )\n"," return stacked_head_vectors\n","\n","\n","def attr_patch_head_vector(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n"," activation_name: Literal[\"q\", \"k\", \"v\", \"z\"],\n",") -> TT[\"component\", \"pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_vector = stack_head_vector_from_cache(clean_cache, activation_name)\n"," corrupted_head_vector = stack_head_vector_from_cache(\n"," corrupted_cache, activation_name\n"," )\n"," corrupted_grad_head_vector = stack_head_vector_from_cache(\n"," corrupted_grad_cache, activation_name\n"," )\n"," head_vector_attr = einops.reduce(\n"," corrupted_grad_head_vector * (clean_head_vector - corrupted_head_vector),\n"," \"component batch pos d_head -> component pos\",\n"," \"sum\",\n"," )\n"," return head_vector_attr, labels\n","\n","\n","head_vector_attr_dict = {}\n","for activation_name, activation_name_full in [\n"," (\"k\", \"Key\"),\n"," (\"q\", \"Query\"),\n"," (\"v\", \"Value\"),\n"," (\"z\", \"Mixed Value\"),\n","]:\n"," display(Markdown(f\"#### {activation_name_full} Head Vector Attribution Patching\"))\n"," head_vector_attr_dict[activation_name], head_vector_labels = attr_patch_head_vector(\n"," clean_cache, corrupted_cache, corrupted_grad_cache, activation_name\n"," )\n"," imshow(\n"," head_vector_attr_dict[activation_name],\n"," y=head_vector_labels,\n"," yaxis=\"Component\",\n"," xaxis=\"Position\",\n"," title=f\"{activation_name_full} Attribution Patching\",\n"," )\n"," sum_head_vector_attr = einops.reduce(\n"," head_vector_attr_dict[activation_name],\n"," \"(layer head) pos -> layer head\",\n"," \"sum\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," )\n"," imshow(\n"," sum_head_vector_attr,\n"," yaxis=\"Layer\",\n"," xaxis=\"Head Index\",\n"," title=f\"{activation_name_full} Attribution Patching Sum Over Pos\",\n"," )"]},{"cell_type":"code","execution_count":21,"metadata":{},"outputs":[{"data":{"text/markdown":["### Head Pattern Attribution Patching"],"text/plain":[""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n"," \n","\n"," \n","
\n"," \n"," \n"," "],"text/plain":[""]},"metadata":{},"output_type":"display_data"}],"source":["from typing_extensions import Literal\n","\n","\n","def stack_head_pattern_from_cache(\n"," cache,\n",") -> TT[\"layer_and_head_index\", \"batch\", \"dest_pos\", \"src_pos\"]:\n"," \"\"\"Stacks the head patterns from the cache into a single tensor.\"\"\"\n"," stacked_head_pattern = torch.stack(\n"," [cache[\"pattern\", l] for l in range(model.cfg.n_layers)], dim=0\n"," )\n"," stacked_head_pattern = einops.rearrange(\n"," stacked_head_pattern,\n"," \"layer batch head_index dest_pos src_pos -> (layer head_index) batch dest_pos src_pos\",\n"," )\n"," return stacked_head_pattern\n","\n","\n","def attr_patch_head_pattern(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"component\", \"dest_pos\", \"src_pos\"]:\n"," labels = HEAD_NAMES\n","\n"," clean_head_pattern = stack_head_pattern_from_cache(clean_cache)\n"," corrupted_head_pattern = stack_head_pattern_from_cache(corrupted_cache)\n"," corrupted_grad_head_pattern = stack_head_pattern_from_cache(corrupted_grad_cache)\n"," head_pattern_attr = einops.reduce(\n"," corrupted_grad_head_pattern * (clean_head_pattern - corrupted_head_pattern),\n"," \"component batch dest_pos src_pos -> component dest_pos src_pos\",\n"," \"sum\",\n"," )\n"," return head_pattern_attr, labels\n","\n","\n","head_pattern_attr, labels = attr_patch_head_pattern(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","\n","plot_attention_attr(\n"," einops.rearrange(\n"," head_pattern_attr,\n"," \"(layer head) dest src -> layer head dest src\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," ),\n"," clean_tokens,\n"," index=0,\n"," title=\"Head Pattern Attribution Patching\",\n",")"]},{"cell_type":"code","execution_count":22,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_head_vector_grad_input_from_grad_cache(\n"," grad_cache: ActivationCache, activation_name: Literal[\"q\", \"k\", \"v\"], layer: int\n",") -> TT[\"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," vector_grad = grad_cache[activation_name, layer]\n"," ln_scales = grad_cache[\"scale\", layer, \"ln1\"]\n"," attn_layer_object = model.blocks[layer].attn\n"," if activation_name == \"q\":\n"," W = attn_layer_object.W_Q\n"," elif activation_name == \"k\":\n"," W = attn_layer_object.W_K\n"," elif activation_name == \"v\":\n"," W = attn_layer_object.W_V\n"," else:\n"," raise ValueError(\"Invalid activation name\")\n","\n"," return einsum(\n"," \"batch pos head_index d_head, batch pos, head_index d_model d_head -> batch pos head_index d_model\",\n"," vector_grad,\n"," ln_scales.squeeze(-1),\n"," W,\n"," )\n","\n","\n","def get_stacked_head_vector_grad_input(\n"," grad_cache, activation_name: Literal[\"q\", \"k\", \"v\"]\n",") -> TT[\"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack(\n"," [\n"," get_head_vector_grad_input_from_grad_cache(grad_cache, activation_name, l)\n"," for l in range(model.cfg.n_layers)\n"," ],\n"," dim=0,\n"," )\n","\n","\n","def get_full_vector_grad_input(\n"," grad_cache,\n",") -> TT[\"qkv\", \"layer\", \"batch\", \"pos\", \"head_index\", \"d_model\"]:\n"," return torch.stack(\n"," [\n"," get_stacked_head_vector_grad_input(grad_cache, activation_name)\n"," for activation_name in [\"q\", \"k\", \"v\"]\n"," ],\n"," dim=0,\n"," )\n","\n","\n","def attr_patch_head_path(\n"," clean_cache: ActivationCache,\n"," corrupted_cache: ActivationCache,\n"," corrupted_grad_cache: ActivationCache,\n",") -> TT[\"qkv\", \"dest_component\", \"src_component\", \"pos\"]:\n"," \"\"\"\n"," Computes the attribution patch along the path between each pair of heads.\n","\n"," Sets this to zero for the path from any late head to any early head\n","\n"," \"\"\"\n"," start_labels = HEAD_NAMES\n"," end_labels = HEAD_NAMES_QKV\n"," full_vector_grad_input = get_full_vector_grad_input(corrupted_grad_cache)\n"," clean_head_result_stack = clean_cache.stack_head_results(-1)\n"," corrupted_head_result_stack = corrupted_cache.stack_head_results(-1)\n"," diff_head_result = einops.rearrange(\n"," clean_head_result_stack - corrupted_head_result_stack,\n"," \"(layer head_index) batch pos d_model -> layer batch pos head_index d_model\",\n"," layer=model.cfg.n_layers,\n"," head_index=model.cfg.n_heads,\n"," )\n"," path_attr = einsum(\n"," \"qkv layer_end batch pos head_end d_model, layer_start batch pos head_start d_model -> qkv layer_end head_end layer_start head_start pos\",\n"," full_vector_grad_input,\n"," diff_head_result,\n"," )\n"," correct_layer_order_mask = (\n"," torch.arange(model.cfg.n_layers)[None, :, None, None, None, None]\n"," > torch.arange(model.cfg.n_layers)[None, None, None, :, None, None]\n"," ).to(path_attr.device)\n"," zero = torch.zeros(1, device=path_attr.device)\n"," path_attr = torch.where(correct_layer_order_mask, path_attr, zero)\n","\n"," path_attr = einops.rearrange(\n"," path_attr,\n"," \"qkv layer_end head_end layer_start head_start pos -> (layer_end head_end qkv) (layer_start head_start) pos\",\n"," )\n"," return path_attr, end_labels, start_labels\n","\n","\n","head_path_attr, end_labels, start_labels = attr_patch_head_path(\n"," clean_cache, corrupted_cache, corrupted_grad_cache\n",")\n","imshow(\n"," head_path_attr.sum(-1),\n"," y=end_labels,\n"," yaxis=\"Path End (Head Input)\",\n"," x=start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=\"Head Path Attribution Patching\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" This is hard to parse. Here's an experiment with filtering for the most important heads and showing their paths."]},{"cell_type":"code","execution_count":23,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["head_out_values, head_out_indices = head_out_attr.sum(-1).abs().sort(descending=True)\n","line(head_out_values)\n","top_head_indices = head_out_indices[:22].sort().values\n","top_end_indices = []\n","top_end_labels = []\n","top_start_indices = []\n","top_start_labels = []\n","for i in top_head_indices:\n"," i = i.item()\n"," top_start_indices.append(i)\n"," top_start_labels.append(start_labels[i])\n"," for j in range(3):\n"," top_end_indices.append(3 * i + j)\n"," top_end_labels.append(end_labels[3 * i + j])\n","\n","imshow(\n"," head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n"," y=top_end_labels,\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=\"Head Path Attribution Patching (Filtered for Top Heads)\",\n",")"]},{"cell_type":"code","execution_count":24,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["for j, composition_type in enumerate([\"Query\", \"Key\", \"Value\"]):\n"," imshow(\n"," head_path_attr[top_end_indices, :][:, top_start_indices][j::3].sum(-1),\n"," y=top_end_labels[j::3],\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=f\"Head Path to {composition_type} Attribution Patching (Filtered for Top Heads)\",\n"," )"]},{"cell_type":"code","execution_count":25,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["top_head_path_attr = einops.rearrange(\n"," head_path_attr[top_end_indices, :][:, top_start_indices].sum(-1),\n"," \"(head_end qkv) head_start -> qkv head_end head_start\",\n"," qkv=3,\n",")\n","imshow(\n"," top_head_path_attr,\n"," y=[i[:-1] for i in top_end_labels[::3]],\n"," yaxis=\"Path End (Head Input)\",\n"," x=top_start_labels,\n"," xaxis=\"Path Start (Head Output)\",\n"," title=f\"Head Path Attribution Patching (Filtered for Top Heads)\",\n"," facet_col=0,\n"," facet_labels=[\"Query\", \"Key\", \"Value\"],\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" Let's now dive into 3 interesting heads: L5H5 (induction head), L8H6 (S-Inhibition Head), L9H9 (Name Mover) and look at their input and output paths (note - Q input means )"]},{"cell_type":"code","execution_count":26,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["interesting_heads = [\n"," 5 * model.cfg.n_heads + 5,\n"," 8 * model.cfg.n_heads + 6,\n"," 9 * model.cfg.n_heads + 9,\n","]\n","interesting_head_labels = [HEAD_NAMES[i] for i in interesting_heads]\n","for head_index, label in zip(interesting_heads, interesting_head_labels):\n"," in_paths = head_path_attr[3 * head_index : 3 * head_index + 3].sum(-1)\n"," out_paths = head_path_attr[:, head_index].sum(-1)\n"," out_paths = einops.rearrange(out_paths, \"(layer_head qkv) -> qkv layer_head\", qkv=3)\n"," all_paths = torch.cat([in_paths, out_paths], dim=0)\n"," all_paths = einops.rearrange(\n"," all_paths,\n"," \"path_type (layer head) -> path_type layer head\",\n"," layer=model.cfg.n_layers,\n"," head=model.cfg.n_heads,\n"," )\n"," imshow(\n"," all_paths,\n"," facet_col=0,\n"," facet_labels=[\n"," \"Query (In)\",\n"," \"Key (In)\",\n"," \"Value (In)\",\n"," \"Query (Out)\",\n"," \"Key (Out)\",\n"," \"Value (Out)\",\n"," ],\n"," title=f\"Input and Output Paths for head {label}\",\n"," yaxis=\"Layer\",\n"," xaxis=\"Head\",\n"," )"]},{"cell_type":"markdown","metadata":{},"source":[" ## Validating Attribution vs Activation Patching\n"," Let's now compare attribution and activation patching. Generally it's a decent approximation! The main place it fails is MLP0 and the residual stream\n"," My fuzzy intuition is that attribution patching works badly for \"big\" things which are poorly modelled as linear approximations, and works well for \"small\" things which are more like incremental changes. Anything involving replacing the embedding is a \"big\" thing, which includes residual streams, and in GPT-2 small MLP0 seems to be used as an \"extended embedding\" (where later layers use MLP0's output instead of the token embedding), so I also count it as big.\n"," See more discussion in the accompanying blog post!\n"]},{"cell_type":"markdown","metadata":{},"source":[" First do some refactoring to make attribution patching more generic. We make an attribution cache, which is an ActivationCache where each element is (clean_act - corrupted_act) * corrupted_grad, so that it's the per-element attribution for each activation. Thanks to linearity, we just compute things by adding stuff up along the relevant dimensions!"]},{"cell_type":"code","execution_count":27,"metadata":{},"outputs":[],"source":["attribution_cache_dict = {}\n","for key in corrupted_grad_cache.cache_dict.keys():\n"," attribution_cache_dict[key] = corrupted_grad_cache.cache_dict[key] * (\n"," clean_cache.cache_dict[key] - corrupted_cache.cache_dict[key]\n"," )\n","attr_cache = ActivationCache(attribution_cache_dict, model)"]},{"cell_type":"markdown","metadata":{},"source":[" By block: For each head we patch the starting residual stream, attention output + MLP output"]},{"cell_type":"code","execution_count":28,"metadata":{},"outputs":[],"source":["str_tokens = model.to_str_tokens(clean_tokens[0])\n","context_length = len(str_tokens)"]},{"cell_type":"code","execution_count":29,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"95a5290e11b64b6a95ef5dd37d027c7a","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/180 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_block_act_patch_result = patching.get_act_patch_block_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","imshow(\n"," every_block_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Activation Patching Per Block\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",")"]},{"cell_type":"code","execution_count":30,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_block_every(attr_cache):\n"," resid_pre_attr = einops.reduce(\n"," attr_cache.stack_activation(\"resid_pre\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," attn_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"attn_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n"," mlp_out_attr = einops.reduce(\n"," attr_cache.stack_activation(\"mlp_out\"),\n"," \"layer batch pos d_model -> layer pos\",\n"," \"sum\",\n"," )\n","\n"," every_block_attr_patch_result = torch.stack(\n"," [resid_pre_attr, attn_out_attr, mlp_out_attr], dim=0\n"," )\n"," return every_block_attr_patch_result\n","\n","\n","every_block_attr_patch_result = get_attr_patch_block_every(attr_cache)\n","imshow(\n"," every_block_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Attribution Patching Per Block\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n",")"]},{"cell_type":"code","execution_count":31,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_block_attr_patch_result.reshape(3, -1),\n"," x=every_block_act_patch_result.reshape(3, -1),\n"," facet_col=0,\n"," facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n"," title=\"Attribution vs Activation Patching Per Block\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," hover=[\n"," f\"Layer {l}, Position {p}, |{str_tokens[p]}|\"\n"," for l in range(model.cfg.n_layers)\n"," for p in range(context_length)\n"," ],\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers), \"layer -> (layer pos)\", pos=context_length\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow."]},{"cell_type":"code","execution_count":32,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"18b2e6b0985b40cd8c0cd1a16ba62975","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/144 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","imshow(\n"," every_head_all_pos_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Activation Patching Per Head (All Pos)\",\n"," xaxis=\"Head\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n",")"]},{"cell_type":"code","execution_count":33,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_all_pos_every(attr_cache):\n"," head_out_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_q_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_k_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_v_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_all_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack(\n"," [\n"," head_out_all_pos_attr,\n"," head_q_all_pos_attr,\n"," head_k_all_pos_attr,\n"," head_v_all_pos_attr,\n"," head_pattern_all_pos_attr,\n"," ]\n"," )\n","\n","\n","every_head_all_pos_attr_patch_result = get_attr_patch_attn_head_all_pos_every(\n"," attr_cache\n",")\n","imshow(\n"," every_head_all_pos_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution Patching Per Head (All Pos)\",\n"," xaxis=\"Head\",\n"," yaxis=\"Layer\",\n"," zmax=1,\n"," zmin=-1,\n",")"]},{"cell_type":"code","execution_count":34,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_head_all_pos_attr_patch_result.reshape(5, -1),\n"," x=every_head_all_pos_act_patch_result.reshape(5, -1),\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution vs Activation Patching Per Head (All Pos)\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," include_diag=True,\n"," hover=head_out_labels,\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers),\n"," \"layer -> (layer head)\",\n"," head=model.cfg.n_heads,\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" We see pretty good results in general, but significant errors for heads L5H5 on query and moderate errors for head L10H7 on query and key, and moderate errors for head L11H10 on key. But each of these is fine for pattern and output. My guess is that the problem is that these have pretty saturated attention on a single token, and the linear approximation is thus not great on the attention calculation here, but I'm not sure. When we plot the attention patterns, we do see this!\n"," Note that the axis labels are for the *first* prompt's tokens, but each facet is a different prompt, so this is somewhat inaccurate. In particular, every odd facet has indirect object and subject in the opposite order (IO first). But otherwise everything lines up between the prompts"]},{"cell_type":"code","execution_count":35,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"},{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["graph_tok_labels = [\n"," f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))\n","]\n","imshow(\n"," clean_cache[\"pattern\", 5][:, 5],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L5H5\",\n"," facet_name=\"Prompt\",\n",")\n","imshow(\n"," clean_cache[\"pattern\", 10][:, 7],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L10H7\",\n"," facet_name=\"Prompt\",\n",")\n","imshow(\n"," clean_cache[\"pattern\", 11][:, 10],\n"," x=graph_tok_labels,\n"," y=graph_tok_labels,\n"," facet_col=0,\n"," title=\"Attention for Head L11H10\",\n"," facet_name=\"Prompt\",\n",")\n","\n","\n","# [markdown]"]},{"cell_type":"code","execution_count":36,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"06f39489001845849fbc7446a07066f4","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/2160 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["every_head_by_pos_act_patch_result = patching.get_act_patch_attn_head_by_pos_every(\n"," model, corrupted_tokens, clean_cache, ioi_metric\n",")\n","every_head_by_pos_act_patch_result = einops.rearrange(\n"," every_head_by_pos_act_patch_result,\n"," \"act_type layer pos head -> act_type (layer head) pos\",\n",")\n","imshow(\n"," every_head_by_pos_act_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Activation Patching Per Head (By Pos)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer & Head\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," y=head_out_labels,\n",")"]},{"cell_type":"code","execution_count":37,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def get_attr_patch_attn_head_by_pos_every(attr_cache):\n"," head_out_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"z\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_q_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"q\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_k_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"k\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_v_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"v\"),\n"," \"layer batch pos head_index d_head -> layer pos head_index\",\n"," \"sum\",\n"," )\n"," head_pattern_by_pos_attr = einops.reduce(\n"," attr_cache.stack_activation(\"pattern\"),\n"," \"layer batch head_index dest_pos src_pos -> layer dest_pos head_index\",\n"," \"sum\",\n"," )\n","\n"," return torch.stack(\n"," [\n"," head_out_by_pos_attr,\n"," head_q_by_pos_attr,\n"," head_k_by_pos_attr,\n"," head_v_by_pos_attr,\n"," head_pattern_by_pos_attr,\n"," ]\n"," )\n","\n","\n","every_head_by_pos_attr_patch_result = get_attr_patch_attn_head_by_pos_every(attr_cache)\n","every_head_by_pos_attr_patch_result = einops.rearrange(\n"," every_head_by_pos_attr_patch_result,\n"," \"act_type layer pos head -> act_type (layer head) pos\",\n",")\n","imshow(\n"," every_head_by_pos_attr_patch_result,\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution Patching Per Head (By Pos)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer & Head\",\n"," zmax=1,\n"," zmin=-1,\n"," x=[f\"{tok}_{i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],\n"," y=head_out_labels,\n",")"]},{"cell_type":"code","execution_count":38,"metadata":{},"outputs":[{"data":{"text/html":["\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["scatter(\n"," y=every_head_by_pos_attr_patch_result.reshape(5, -1),\n"," x=every_head_by_pos_act_patch_result.reshape(5, -1),\n"," facet_col=0,\n"," facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n"," title=\"Attribution vs Activation Patching Per Head (by Pos)\",\n"," xaxis=\"Activation Patch\",\n"," yaxis=\"Attribution Patch\",\n"," include_diag=True,\n"," hover=[f\"{label} {tok}\" for label in head_out_labels for tok in graph_tok_labels],\n"," color=einops.repeat(\n"," torch.arange(model.cfg.n_layers),\n"," \"layer -> (layer head pos)\",\n"," head=model.cfg.n_heads,\n"," pos=15,\n"," ),\n"," color_continuous_scale=\"Portland\",\n",")"]},{"cell_type":"markdown","metadata":{},"source":[" ## Factual Knowledge Patching Example\n"," Incomplete, but maybe of interest!\n"," Note that I have better results with the corrupted prompt as having random words rather than Colosseum."]},{"cell_type":"code","execution_count":39,"metadata":{},"outputs":[{"name":"stderr","output_type":"stream","text":["Using pad_token, but it is not set yet.\n"]},{"name":"stdout","output_type":"stream","text":["Loaded pretrained model gpt2-xl into HookedTransformer\n","Tokenized prompt: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Paris']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.73 Prob: 95.80% Token: | Paris|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.73\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m95.80\u001b[0m\u001b[1m% Token: | Paris|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.73 Prob: 95.80% Token: | Paris|\n","Top 1th token. Logit: 16.49 Prob: 1.39% Token: | E|\n","Top 2th token. Logit: 14.69 Prob: 0.23% Token: | the|\n","Top 3th token. Logit: 14.58 Prob: 0.21% Token: | É|\n","Top 4th token. Logit: 14.44 Prob: 0.18% Token: | France|\n","Top 5th token. Logit: 14.36 Prob: 0.16% Token: | Mont|\n","Top 6th token. Logit: 13.77 Prob: 0.09% Token: | Le|\n","Top 7th token. Logit: 13.66 Prob: 0.08% Token: | Ang|\n","Top 8th token. Logit: 13.43 Prob: 0.06% Token: | V|\n","Top 9th token. Logit: 13.42 Prob: 0.06% Token: | Stras|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Paris', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Paris'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Tokenized prompt: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n","Tokenized answer: [' Rome']\n"]},{"data":{"text/html":["
Performance on answer token:\n","Rank: 0        Logit: 20.02 Prob: 83.70% Token: | Rome|\n","
\n"],"text/plain":["Performance on answer token:\n","\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m20.02\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m83.70\u001b[0m\u001b[1m% Token: | Rome|\u001b[0m\n"]},"metadata":{},"output_type":"display_data"},{"name":"stdout","output_type":"stream","text":["Top 0th token. Logit: 20.02 Prob: 83.70% Token: | Rome|\n","Top 1th token. Logit: 17.03 Prob: 4.23% Token: | Naples|\n","Top 2th token. Logit: 16.85 Prob: 3.51% Token: | Pompe|\n","Top 3th token. Logit: 16.14 Prob: 1.73% Token: | Ver|\n","Top 4th token. Logit: 15.87 Prob: 1.32% Token: | Florence|\n","Top 5th token. Logit: 14.77 Prob: 0.44% Token: | Roma|\n","Top 6th token. Logit: 14.68 Prob: 0.40% Token: | Milan|\n","Top 7th token. Logit: 14.66 Prob: 0.39% Token: | ancient|\n","Top 8th token. Logit: 14.37 Prob: 0.29% Token: | Pal|\n","Top 9th token. Logit: 14.30 Prob: 0.27% Token: | Constantinople|\n"]},{"data":{"text/html":["
Ranks of the answer tokens: [(' Rome', 0)]\n","
\n"],"text/plain":["\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Rome'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n"]},"metadata":{},"output_type":"display_data"}],"source":["gpt2_xl = HookedTransformer.from_pretrained(\"gpt2-xl\")\n","clean_prompt = \"The Eiffel Tower is located in the city of\"\n","clean_answer = \" Paris\"\n","# corrupted_prompt = \"The red brown fox jumps is located in the city of\"\n","corrupted_prompt = \"The Colosseum is located in the city of\"\n","corrupted_answer = \" Rome\"\n","utils.test_prompt(clean_prompt, clean_answer, gpt2_xl)\n","utils.test_prompt(corrupted_prompt, corrupted_answer, gpt2_xl)"]},{"cell_type":"code","execution_count":40,"metadata":{},"outputs":[],"source":["clean_answer_index = gpt2_xl.to_single_token(clean_answer)\n","corrupted_answer_index = gpt2_xl.to_single_token(corrupted_answer)\n","\n","\n","def factual_logit_diff(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return logits[0, -1, clean_answer_index] - logits[0, -1, corrupted_answer_index]"]},{"cell_type":"code","execution_count":41,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean logit diff: 10.634519577026367\n","Corrupted logit diff: -8.988396644592285\n","Clean Metric: tensor(1., device='cuda:0', grad_fn=)\n","Corrupted Metric: tensor(0., device='cuda:0', grad_fn=)\n"]}],"source":["clean_logits, clean_cache = gpt2_xl.run_with_cache(clean_prompt)\n","CLEAN_LOGIT_DIFF_FACTUAL = factual_logit_diff(clean_logits).item()\n","corrupted_logits, _ = gpt2_xl.run_with_cache(corrupted_prompt)\n","CORRUPTED_LOGIT_DIFF_FACTUAL = factual_logit_diff(corrupted_logits).item()\n","\n","\n","def factual_metric(logits: TT[\"batch\", \"position\", \"d_vocab\"]):\n"," return (factual_logit_diff(logits) - CORRUPTED_LOGIT_DIFF_FACTUAL) / (\n"," CLEAN_LOGIT_DIFF_FACTUAL - CORRUPTED_LOGIT_DIFF_FACTUAL\n"," )\n","\n","\n","print(\"Clean logit diff:\", CLEAN_LOGIT_DIFF_FACTUAL)\n","print(\"Corrupted logit diff:\", CORRUPTED_LOGIT_DIFF_FACTUAL)\n","print(\"Clean Metric:\", factual_metric(clean_logits))\n","print(\"Corrupted Metric:\", factual_metric(corrupted_logits))"]},{"cell_type":"code","execution_count":42,"metadata":{},"outputs":[],"source":["# corrupted_value, corrupted_cache, corrupted_grad_cache = get_cache_fwd_and_bwd(gpt2_xl, corrupted_prompt, factual_metric)"]},{"cell_type":"code","execution_count":43,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["Clean: ['<|endoftext|>', 'The', ' E', 'iff', 'el', ' Tower', ' is', ' located', ' in', ' the', ' city', ' of']\n","Corrupted: ['<|endoftext|>', 'The', ' Col', 'os', 'se', 'um', ' is', ' located', ' in', ' the', ' city', ' of']\n"]}],"source":["clean_tokens = gpt2_xl.to_tokens(clean_prompt)\n","clean_str_tokens = gpt2_xl.to_str_tokens(clean_prompt)\n","corrupted_tokens = gpt2_xl.to_tokens(corrupted_prompt)\n","corrupted_str_tokens = gpt2_xl.to_str_tokens(corrupted_prompt)\n","print(\"Clean:\", clean_str_tokens)\n","print(\"Corrupted:\", corrupted_str_tokens)"]},{"cell_type":"code","execution_count":44,"metadata":{},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"b767eef7a3cd49b9b3cb6e5301463f08","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/48 [00:00\n","\n","\n","
\n","
\n","\n",""]},"metadata":{},"output_type":"display_data"}],"source":["def act_patch_residual(clean_cache, corrupted_tokens, model: HookedTransformer, metric):\n"," if len(corrupted_tokens.shape) == 2:\n"," corrupted_tokens = corrupted_tokens[0]\n"," residual_patches = torch.zeros(\n"," (model.cfg.n_layers, len(corrupted_tokens)), device=model.cfg.device\n"," )\n","\n"," def residual_hook(resid_pre, hook, layer, pos):\n"," resid_pre[:, pos, :] = clean_cache[\"resid_pre\", layer][:, pos, :]\n"," return resid_pre\n","\n"," for layer in tqdm.tqdm(range(model.cfg.n_layers)):\n"," for pos in range(len(corrupted_tokens)):\n"," patched_logits = model.run_with_hooks(\n"," corrupted_tokens,\n"," fwd_hooks=[\n"," (\n"," f\"blocks.{layer}.hook_resid_pre\",\n"," partial(residual_hook, layer=layer, pos=pos),\n"," )\n"," ],\n"," )\n"," residual_patches[layer, pos] = metric(patched_logits).item()\n"," return residual_patches\n","\n","\n","residual_act_patch = act_patch_residual(\n"," clean_cache, corrupted_tokens, gpt2_xl, factual_metric\n",")\n","\n","imshow(\n"," residual_act_patch,\n"," title=\"Factual Recall Patching (Residual)\",\n"," xaxis=\"Position\",\n"," yaxis=\"Layer\",\n"," x=clean_str_tokens,\n",")"]}],"metadata":{"kernelspec":{"display_name":"base","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.11.8"},"orig_nbformat":4,"vscode":{"interpreter":{"hash":"d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"}}},"nbformat":4,"nbformat_minor":2} From 16ee6b087909736c4460dc89cc9fe5e1a6164da6 Mon Sep 17 00:00:00 2001 From: Anthony Duong Date: Mon, 3 Jun 2024 22:45:47 -0700 Subject: [PATCH 3/4] fixes in cell outputs --- demos/Head_Detector_Demo.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/demos/Head_Detector_Demo.ipynb b/demos/Head_Detector_Demo.ipynb index af55fcaeb..224eb1a35 100644 --- a/demos/Head_Detector_Demo.ipynb +++ b/demos/Head_Detector_Demo.ipynb @@ -145,9 +145,9 @@ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10->transformer-lens==0.0.0) (1.3.0)\n", "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens==0.0.0) (5.0.0)\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting git+https://github.com/TransformerLensOrg/neel-plotly.git\n", + "Collecting git+https://github.com/neelnanda-io/neel-plotly.git\n", " Cloning https://github.com/TransformerLensOrg/neel-plotly.git to /tmp/pip-req-build-u8mujxc3\n", - " Running command git clone --filter=blob:none --quiet https://github.com/TransformerLensOrg/neel-plotly.git /tmp/pip-req-build-u8mujxc3\n", + " Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /tmp/pip-req-build-u8mujxc3\n", " Resolved https://github.com/TransformerLensOrg/neel-plotly.git to commit 6dc096fdc575da978d3e56489f2347d95cd397e7\n", " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (from neel-plotly==0.0.0) (0.6.1)\n", From 55cca8bb4cb7dde74aaf975488eb3064b1cf919d Mon Sep 17 00:00:00 2001 From: Anthony Duong Date: Mon, 3 Jun 2024 22:50:59 -0700 Subject: [PATCH 4/4] fixes more in Head_Detector_Demo.ipynb --- demos/Head_Detector_Demo.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/demos/Head_Detector_Demo.ipynb b/demos/Head_Detector_Demo.ipynb index 224eb1a35..33c9b09d8 100644 --- a/demos/Head_Detector_Demo.ipynb +++ b/demos/Head_Detector_Demo.ipynb @@ -146,9 +146,9 @@ "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer-lens==0.0.0) (5.0.0)\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Collecting git+https://github.com/neelnanda-io/neel-plotly.git\n", - " Cloning https://github.com/TransformerLensOrg/neel-plotly.git to /tmp/pip-req-build-u8mujxc3\n", + " Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-u8mujxc3\n", " Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /tmp/pip-req-build-u8mujxc3\n", - " Resolved https://github.com/TransformerLensOrg/neel-plotly.git to commit 6dc096fdc575da978d3e56489f2347d95cd397e7\n", + " Resolved https://github.com/neelnanda-io/neel-plotly.git to commit 6dc096fdc575da978d3e56489f2347d95cd397e7\n", " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "Requirement already satisfied: einops in /usr/local/lib/python3.10/dist-packages (from neel-plotly==0.0.0) (0.6.1)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from neel-plotly==0.0.0) (1.24.3)\n",