minor fix on requirements, start cleaning up the tutorials
frankaging committed Jan 11, 2024
1 parent 8e6b310 commit 561e71b
Showing 6 changed files with 87 additions and 113 deletions.
83 changes: 29 additions & 54 deletions README.md
@@ -16,23 +16,24 @@
To interpret causal mechanisms of neural networks with their internals, we introduce **pyvene**, an open-source and intervention-oriented Python library that supports customizable interventions on different families of neural architectures (e.g., RNNs or Transformers). The basic operation is an in-place activation modification during the computation flow of a neural model. It supports complex intervention schemas (e.g., parallel or serialized interventions) and a wide range of intervention modes (e.g., static or trained interventions) to enable practitioners to quantify counterfactual behaviors at scale and gain interpretability insights. We show that **pyvene** supports, out of the box, a wide range of intervention-based interpretability methods such as causal abstraction, circuit finding, and knowledge localization. **pyvene** provides a unified and extensible framework for performing interventions on neural models and for sharing them with others.


## Interventions v.s. Alignments with Model Internals
In this section, we discuss topics from interventions to alignments with model internals.
## Installation
Install this package directly from the source code as,
```bash
pip install git+https://github.com/frankaging/pyvene.git
```

### Interventions
Intervention is the basic unit of this library. It means manipulating the model's activations, without any assumption of how the model behavior will change. We can zero-out a set of neurons, or swap activations between examples (i.e., interchange interventions). Here, we show how we can intervene in model internals with this library.
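
As a library-independent illustration of these two operations, here is a minimal sketch on a toy causal model written in plain Python. It is not pyvene's API: `block`, `run`, `set_position`, and the `edit` callback are hypothetical names introduced only for this example, and the running-mean `block` merely stands in for a causally masked layer.

```python
from typing import Callable, List, Optional

Hidden = List[float]
POS = 4  # position to intervene on (the last of five positions)


def block(xs: Hidden) -> Hidden:
    """Toy causal layer: position i is the running mean of positions 0..i."""
    out, total = [], 0.0
    for i, x in enumerate(xs):
        total += x
        out.append(total / (i + 1))
    return out


def run(tokens: Hidden, edit: Optional[Callable[[Hidden], Hidden]] = None) -> Hidden:
    """Run two toy blocks, optionally editing the activations in between."""
    hidden = block(tokens)
    if edit is not None:
        hidden = edit(hidden)  # the intervention: modify activations mid-computation
    return block(hidden)


def set_position(hidden: Hidden, pos: int, value: float) -> Hidden:
    """Return a copy of `hidden` with one position overwritten."""
    return hidden[:pos] + [value] + hidden[pos + 1:]


base = [3.0, 5.0, 7.0, 9.0, 11.0]
source = [3.0, 5.0, 7.0, 19.0, 11.0]  # differs from base at position 3 only

# Activations of the source example at the intervention site.
source_hidden = block(source)

original = run(base)
# Zero-out intervention: overwrite the activation at POS with 0.
zeroed = run(base, lambda h: set_position(h, POS, 0.0))
# Interchange intervention: copy the activation at POS from the source run.
swapped = run(base, lambda h: set_position(h, POS, source_hidden[POS]))

diff = [o - s for o, s in zip(original, swapped)]
print(diff)  # only the entry at POS is nonzero
```

Because each position in the toy model reads only from itself and earlier positions, editing the last position cannot change any earlier output. That is the same reasoning behind expecting an effect only at the intervened position in a causally masked language model.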
## Basic Interventions
You can intervene with supported models as,
```python
# helper functions to load gpt2 from huggingface
from pyvene.models.gpt2.modelings_intervenable_gpt2 import create_gpt2

#### Loading models from HuggingFace
```py
from models.utils import create_gpt2
from pyvene import IntervenableRepresentationConfig, IntervenableConfig, IntervenableModel
from pyvene import VanillaIntervention

config, tokenizer, gpt = create_gpt2()
```
config, tokenizer, gpt2 = create_gpt2()

#### Create a simple intervenable config
```py
intervenable_config = IntervenableConfig(
intervenable_model_type="gpt2",
intervenable_representations=[
IntervenableRepresentationConfig(
0, # intervening layer 0
@@ -42,28 +43,32 @@ intervenable_config = IntervenableConfig(
),
],
)
```

#### Turn the model into an intervenable object
The basic idea is to consider the intervenable model as a regular HuggingFace model, except that it supports an intervenable forward function.
```py
intervenable_gpt = IntervenableModel(intervenable_config, gpt)
```
intervenable_gpt2 = IntervenableModel(intervenable_config, gpt2)

#### Intervene by swapping activations between examples
```py
base = tokenizer("The capital of Spain is", return_tensors="pt")
sources = [tokenizer("The capital of Italy is", return_tensors="pt")]

_, counterfactual_outputs = intervenable_gpt(
original_outputs, intervened_outputs = intervenable_gpt2(
base,
sources,
{"sources->base": ([[[4]]], [[[4]]])} # intervene base with sources
{"sources->base": ([[[4]]], [[[4]]])} # intervene base with sources at position 4 (the last token).
)
original_outputs.last_hidden_state - intervened_outputs.last_hidden_state
```
---
which returns,

```
tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.0008, -0.0078, -0.0066, ..., 0.0007, -0.0018, 0.0060]]])
```
showing that we have causal effects only on the last token as expected.


### Alignments
## From Interventions to Gain Interpretability Insights
If the model responds systematically to your interventions, then you start to associate certain regions in the network with a high-level concept. This is an alignment. Here is a more concrete example,
```py
def add_three_numbers(a, b, c):
@@ -119,25 +124,6 @@ intervenable.train(
where you need to pass in a trainable dataset, and your customized loss and metrics function. The trainable interventions can later be saved to disk. You can also use `intervenable.evaluate()` to evaluate your interventions against customized objectives.


## Tutorials
We released [a set of tutorials](https://github.com/frankaging/pyvene/tree/main/tutorials) for doing model interventions and model alignments. Here are some of them,

### `Basic_Intervention.ipynb`
(**Intervention Tutorial**) This is a tutorial for doing simple path patching as in **Path Patching**[^pp], **Causal Scrubbing**[^cs]. Thanks to [Aryaman Arora](https://aryaman.io/). This is a set of experiments trying to reproduce some of the experiments in his awesome [nano-causal-interventions](https://github.com/aryamanarora/nano-causal-interventions) repository.

### `Intervened_Model_Generation.ipynb`
(**Intervention Tutorial**) This is a tutorial on how to intervene on the TinyStories-33M model to change its story generation, with sad endings and happy endings. Unlike the other tutorials, this one involves multi-token language generation, which is closer to real-world use cases.

### `Intervention_Training.ipynb`
(**Alignment Tutorial**) This is a tutorial covering the basics of how to train an intervention to find alignments with a gpt2 model finetuned on a logical reasoning task.

### `DAS_with_IOI.ipynb`
(**Alignment Tutorial**) This is a tutorial reproducing key components (i.e., name mover heads, name position information) for the indirect object identification (IOI) circuit introduced by Wang et al. (2023).

### `NonTransformer_MLP_Intervention.ipynb` and `NonTransformer_GRU_Intervention.ipynb`
(**Intervention Tutorial**) These are tutorials for non-Transformer models such as MLPs and GRUs.


## Unit-tests
When adding new methods or APIs, unit tests are now enforced. To run existing tests, you can kick off the python unittest command in the discovery mode as,
```bash
@@ -147,13 +133,6 @@ python -m unittest discover -p '*TestCase.py'
When checking in new code, please also consider adding new tests in the same PR. Please include test results in the PR to show that all existing test cases pass. See the `qa_runbook.ipynb` notebook for conventions on adding test cases. The code coverage for this repository is currently `low`, and we are adding more automated tests.


## System Requirements
- Python 3.8 is supported.
- Pytorch Version: >= 2.0
- Transformers ToT is recommended
- Datasets Version ToT is recommended


## Related Works in Discovering Causal Mechanism of LLMs
If you would like to read more work in this area, here is a list of papers that try to align or discover the causal mechanisms of LLMs.
- [Causal Abstractions of Neural Networks](https://arxiv.org/abs/2106.02997): This paper introduces interchange intervention (a.k.a. activation patching or causal scrubbing). It tries to align a causal model with the model's representations.
@@ -181,7 +160,3 @@ If you use this repository, please consider to cite relevant papers:
booktitle={NeurIPS}
}
```

[^pp]: [Wang et al. (2022)](https://arxiv.org/abs/2211.00593), [Goldowsky-Dill et al. (2023)](https://arxiv.org/abs/2304.05969)
[^cs]: [Chan et al. (2022)](https://www.lesswrong.com/s/h95ayYYwMebGEYN5y)
[^ii]: [Geiger et al. (2021a)](https://arxiv.org/abs/2106.02997), [Geiger et al. (2021b)](https://arxiv.org/abs/2112.00826), [Geiger et al. (2023)](https://arxiv.org/abs/2301.04709), [Wu et al. (2023)](https://arxiv.org/pdf/2303.02536)
2 changes: 1 addition & 1 deletion pyvene/models/configuration_intervenable_model.py
@@ -21,7 +21,7 @@
class IntervenableConfig(PretrainedConfig):
def __init__(
self,
intervenable_model_type="gpt2",
intervenable_model_type=None,
intervenable_representations=[IntervenableRepresentationConfig()],
intervenable_interventions_type=VanillaIntervention,
mode="parallel",
1 change: 1 addition & 0 deletions pyvene/models/intervenable_base.py
@@ -29,6 +29,7 @@ def __init__(self, intervenable_config, model, **kwargs):
self.mode = intervenable_config.mode
intervention_type = intervenable_config.intervenable_interventions_type
self.is_model_stateless = is_stateless(model)
self.intervenable_config.intervenable_model_type = type(model) # backfill
self.use_fast = kwargs["use_fast"] if "use_fast" in kwargs else False
if self.use_fast:
logging.warn(
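
The two hunks above work together: the config's `intervenable_model_type` now defaults to `None`, and the wrapper fills it in from the model it receives. The pattern can be sketched in isolation as follows. This is a simplified illustration, not pyvene's code: `ToyConfig`, `ToyWrapper`, and `ToyModel` are hypothetical stand-ins.

```python
from typing import Optional


class ToyConfig:
    """Hypothetical stand-in for a config whose model type is optional."""

    def __init__(self, model_type: Optional[type] = None):
        self.model_type = model_type


class ToyModel:
    """Hypothetical stand-in for the wrapped model."""


class ToyWrapper:
    """Hypothetical wrapper that backfills the config from the wrapped model."""

    def __init__(self, config: ToyConfig, model: object):
        self.config = config
        self.model = model
        # Backfill: record the concrete class of the model that was passed in,
        # so callers no longer have to name the model type themselves.
        self.config.model_type = type(model)


config = ToyConfig()  # no model type given up front
wrapper = ToyWrapper(config, ToyModel())
print(config.model_type.__name__)
```

As in the diff, the assignment in this sketch is unconditional, so any value already set on the config is overwritten by `type(model)`.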
3 changes: 1 addition & 2 deletions requirements.txt
@@ -2,8 +2,7 @@ torch
transformers
datasets
notebook
sentencepiece
protobuf==3.20.0
protobuf==3.19.1
matplotlib
wandb
seaborn
53 changes: 22 additions & 31 deletions tutorials/basic_tutorials/Add_Activations_to_Streams.ipynb
@@ -50,44 +50,43 @@
"execution_count": 2,
"id": "c34ae314",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2024-01-11 00:31:07,569] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
]
}
],
"source": [
"try:\n",
" # This library is our indicator that the required installs\n",
" # need to be done.\n",
" import transformers\n",
" import sys\n",
" import pyvene\n",
"\n",
" sys.path.append(\"pyvene/\")\n",
"except ModuleNotFoundError:\n",
" !git clone https://github.com/frankaging/pyvene.git\n",
" !pip install -r pyvene/requirements.txt\n",
" import sys\n",
"\n",
" sys.path.append(\"pyvene/\")"
" !pip install git+https://github.com/frankaging/pyvene.git"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "e5cb9eb2",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"sys.path.append(\"../..\")\n",
"\n",
"import torch\n",
"import pandas as pd\n",
"from models.basic_utils import embed_to_distrib, top_vals, format_token\n",
"from models.configuration_intervenable_model import (\n",
"from pyvene.models.basic_utils import embed_to_distrib, top_vals, format_token\n",
"from pyvene import (\n",
" IntervenableModel,\n",
" AdditionIntervention,\n",
" SubtractionIntervention,\n",
" IntervenableRepresentationConfig,\n",
" IntervenableConfig,\n",
")\n",
"from models.intervenable_base import IntervenableModel\n",
"from models.interventions import AdditionIntervention, SubtractionIntervention\n",
"from models.gpt2.modelings_intervenable_gpt2 import create_gpt2\n",
"from pyvene.models.gpt2.modelings_intervenable_gpt2 import create_gpt2\n",
"\n",
"%config InlineBackend.figure_formats = ['svg']\n",
"from plotnine import (\n",
@@ -113,7 +112,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "1fc15f36",
"metadata": {},
"outputs": [
@@ -148,7 +147,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "151ded21",
"metadata": {},
"outputs": [
@@ -214,7 +213,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "0481a874",
"metadata": {},
"outputs": [
@@ -235,7 +234,7 @@
" 'layer.11.repr.mlp_output.unit.pos.nunit.1#0']"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -445,14 +444,6 @@
")\n",
"print(g)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbc92072",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
58 changes: 33 additions & 25 deletions tutorials/basic_tutorials/Basic_Intervention.ipynb

Large diffs are not rendered by default.
