Commit
Merge pull request #65 from stanfordnlp/zen/causaltracingfull
[Minor] incidental changes to the tutorials for updated results
frankaging authored Jan 18, 2024
2 parents 3d4fb2b + 428921d commit cde9f98
Showing 3 changed files with 128 additions and 178 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -111,6 +111,8 @@ class ModelWithIntervenables(nn.Module):
| Beginner | [**Getting Started**](tutorials/basic_tutorials/Basic_Intervention.ipynb) | [<img align="center" src="https://colab.research.google.com/assets/colab-badge.svg" />](https://colab.research.google.com/github/stanfordnlp/pyvene/blob/main/tutorials/basic_tutorials/Basic_Intervention.ipynb) | Introduces basic static intervention on factual recall examples |
| Beginner | [**Intervened Model Generation**](tutorials/advanced_tutorials/Intervened_Model_Generation.ipynb) | [<img align="center" src="https://colab.research.google.com/assets/colab-badge.svg" />](https://colab.research.google.com/github/stanfordnlp/pyvene/blob/main/tutorials/advanced_tutorials/Intervened_Model_Generation.ipynb) | Shows how to intervene a model during generation |
| Intermediate | [**Intervene Your Local Models**](tutorials/basic_tutorials/Add_New_Model_Type.ipynb) | [<img align="center" src="https://colab.research.google.com/assets/colab-badge.svg" />](https://colab.research.google.com/github/stanfordnlp/pyvene/blob/main/tutorials/basic_tutorials/Add_New_Model_Type.ipynb) | Illustrates how to run this library with your own models |
| Intermediate | [**ROME Causal Tracing**](tutorials/advanced_tutorials/Causal_Tracing.ipynb) | [<img align="center" src="https://colab.research.google.com/assets/colab-badge.svg" />](https://colab.research.google.com/github/stanfordnlp/pyvene/blob/main/tutorials/advanced_tutorials/Causal_Tracing.ipynb) | Reproduce ROME's Results on Factual Associations with GPT2-XL |
| Intermediate | [**Intervention v.s. Probing**](tutorials/advanced_tutorials/Probing_Gender.ipynb) | [<img align="center" src="https://colab.research.google.com/assets/colab-badge.svg" />](https://colab.research.google.com/github/stanfordnlp/pyvene/blob/main/tutorials/advanced_tutorials/Probing_Gender.ipynb) | Illustrates how to run trainable interventions and probing with pythia-6.9B |
| Advanced | [**Trainable Interventions for Causal Abstraction**](tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb) | [<img align="center" src="https://colab.research.google.com/assets/colab-badge.svg" />](https://colab.research.google.com/github/stanfordnlp/pyvene/blob/main/tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb) | Illustrates how to train an intervention to discover causal mechanisms of a neural model |

## Causal Abstraction: From Interventions to Gain Interpretability Insights
64 changes: 27 additions & 37 deletions tutorials/advanced_tutorials/Causal_Tracing.ipynb
@@ -6,7 +6,7 @@
"source": [
"## Tutorial of Causal Tracing\n",
"\n",
"Causal tracing was a methodology for locating where facts are stored in transformer LMs, introduced in the paper [\"Locating and Editing Factual Associations in GPT\" (Meng et al., 2023)](https://arxiv.org/abs/2202.05262). In this notebook, we will implement their method using this library and replicate the first causal tracing example in the paper (figure 1e on page 2)."
"Causal tracing was a methodology for locating where facts are stored in transformer LMs, introduced in the paper [\"Locating and Editing Factual Associations in GPT\" (Meng et al., 2023)](https://arxiv.org/abs/2202.05262). In this notebook, we will implement their method using this library and replicate the first causal tracing example in the paper (full figure 1 on page 2)."
]
},
{
@@ -33,6 +33,21 @@
"### Set-up"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" # This library is our indicator that the required installs\n",
" # need to be done.\n",
" import pyvene\n",
"\n",
"except ModuleNotFoundError:\n",
" !pip install git+https://github.com/frankaging/pyvene.git"
]
},
{
"cell_type": "code",
"execution_count": 5,
@@ -94,7 +109,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 38,
"metadata": {},
"outputs": [
{
@@ -149,7 +164,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
@@ -194,7 +209,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 30,
"metadata": {},
"outputs": [
{
@@ -240,26 +255,10 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"def restore_corrupted_config(layer, stream=\"block_output\"):\n",
" intervenable_config = IntervenableConfig(\n",
" intervenable_representations=[\n",
" IntervenableRepresentationConfig(\n",
" 0, # layer\n",
" \"block_input\", # intervention type\n",
" ),\n",
" IntervenableRepresentationConfig(\n",
" layer, # layer\n",
" stream, # intervention type\n",
" ),\n",
" ],\n",
" intervenable_interventions_type=[NoiseIntervention, VanillaIntervention],\n",
" )\n",
" return intervenable_config\n",
"\n",
"def restore_corrupted_with_interval_config(\n",
" layer, stream=\"mlp_activation\", window=10, num_layers=48):\n",
" start = max(0, layer - window // 2)\n",
@@ -290,7 +289,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 32,
"metadata": {},
"outputs": [
{
@@ -309,20 +308,11 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 48/48 [00:33<00:00, 1.41it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 48/48 [01:44<00:00, 2.18s/it]\n"
]
}
],
"outputs": [],
"source": [
"for stream in [\"block_output\", \"mlp_activation\"]:\n",
"for stream in [\"block_output\", \"mlp_activation\", \"attention_output\"]:\n",
" data = []\n",
" for layer_i in tqdm(range(gpt.config.n_layer)):\n",
" for pos_i in range(7):\n",
@@ -334,10 +324,10 @@
" intervenable = IntervenableModel(intervenable_config, gpt)\n",
" _, counterfactual_outputs = intervenable(\n",
" base,\n",
" [base] + [base]*n_restores,\n",
" [None] + [base]*n_restores,\n",
" {\n",
" \"sources->base\": (\n",
" [[[0, 1, 2, 3]]] + [[[pos_i]]]*n_restores,\n",
" [None] + [[[pos_i]]]*n_restores,\n",
" [[[0, 1, 2, 3]]] + [[[pos_i]]]*n_restores,\n",
" )\n",
" },\n",
@@ -360,7 +350,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 37,
"metadata": {},
"outputs": [
{
240 changes: 99 additions & 141 deletions tutorials/advanced_tutorials/Probing_Gender.ipynb

Large diffs are not rendered by default.
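
Note on the `Causal_Tracing.ipynb` hunk at `@@ -334,10 +324,10 @@`: the call now passes `None` as the first source and as the first source-side location, while the base-side location for that slot stays `[[[0, 1, 2, 3]]]`. A minimal sketch of the list shapes involved, in plain Python with no pyvene dependency. The helper name `build_intervention_inputs` and its `noise_positions` parameter are hypothetical and do not appear in the notebook. The reading that the first slot is the noise intervention is an assumption based on the removed `restore_corrupted_config`:

```python
def build_intervention_inputs(base, pos_i, n_restores, noise_positions=(0, 1, 2, 3)):
    """Hypothetical helper mirroring the list shapes in the updated call.

    Assumption: slot 0 is the noise intervention and the remaining
    n_restores slots restore clean activations at position pos_i.
    """
    # One source per intervention; the first slot gets no source example.
    sources = [None] + [base] * n_restores
    # Source-side locations: None for the first slot, pos_i for each restore.
    source_locations = [None] + [[[pos_i]]] * n_restores
    # Base-side locations: the first slot targets noise_positions,
    # each restore slot targets pos_i.
    base_locations = [[list(noise_positions)]] + [[[pos_i]]] * n_restores
    return sources, {"sources->base": (source_locations, base_locations)}


if __name__ == "__main__":
    example_base = {"input_ids": [10, 11, 12, 13, 14, 15, 16]}  # placeholder input
    sources, unit_locations = build_intervention_inputs(example_base, pos_i=4, n_restores=2)
    print(sources[0], len(sources))
    print(unit_locations["sources->base"])
```

Both tuples keep one entry per intervention, so the source list and the two location lists stay the same length whatever `n_restores` is.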
