Skip to content

Commit

Permalink
[Minor] Adding interchange intervention for SAEs
Browse files Browse the repository at this point in the history
  • Loading branch information
frankaging committed Sep 24, 2024
1 parent 76bb067 commit 21d0c7b
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 81 deletions.
2 changes: 2 additions & 0 deletions pyvene/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@
from .models.interventions import NoiseIntervention
from .models.interventions import SigmoidMaskIntervention
from .models.interventions import AutoencoderIntervention
from .models.interventions import JumpReLUAutoencoderIntervention
from .models.interventions import InterventionOutput


# Utils
from .models.basic_utils import *
from .models.intervention_utils import _do_intervention_by_swap
from .models.intervenable_modelcard import type_to_module_mapping, type_to_dimension_mapping
from .models.gpt2.modelings_intervenable_gpt2 import create_gpt2
from .models.gpt2.modelings_intervenable_gpt2 import create_gpt2_lm
Expand Down
44 changes: 44 additions & 0 deletions pyvene/models/interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,47 @@ def forward(self, base, source, subspaces=None):

def __str__(self):
return f"AutoencoderIntervention()"


class JumpReLUAutoencoderIntervention(TrainableIntervention):
"""Interchange intervention on JumpReLU SAE's latent subspaces"""
def __init__(self, **kwargs):
# Note that we initialise these to zeros because we're loading in pre-trained weights.
# If you want to train your own SAEs then we recommend using blah
super().__init__(**kwargs, keep_last_dim=True)
self.W_enc = torch.nn.Parameter(torch.zeros(self.embed_dim, kwargs["low_rank_dimension"]))
self.W_dec = torch.nn.Parameter(torch.zeros(kwargs["low_rank_dimension"], self.embed_dim))
self.threshold = torch.nn.Parameter(torch.zeros(kwargs["low_rank_dimension"]))
self.b_enc = torch.nn.Parameter(torch.zeros(kwargs["low_rank_dimension"]))
self.b_dec = torch.nn.Parameter(torch.zeros(self.embed_dim))

def encode(self, input_acts):
pre_acts = input_acts @ self.W_enc + self.b_enc
mask = (pre_acts > self.threshold)
acts = mask * torch.nn.functional.relu(pre_acts)
return acts

def decode(self, acts):
return acts @ self.W_dec + self.b_dec

def forward(self, base, source=None, subspaces=None):
# generate latents for base and source runs.
base_latent = self.encode(base)
source_latent = self.encode(source)
# intervention.
intervened_latent = _do_intervention_by_swap(
base_latent,
source_latent,
"interchange",
self.interchange_dim,
subspaces,
subspace_partition=self.subspace_partition,
use_fast=self.use_fast,
)
# decode intervened latent.
recon = self.decode(intervened_latent)
return recon

def __str__(self):
return f"JumpReLUAutoencoderIntervention()"

192 changes: 111 additions & 81 deletions tutorials/basic_tutorials/Sparse_Autoencoder.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
},
{
"cell_type": "code",
"execution_count": 68,
"execution_count": 1,
"id": "e5d14f0a-b02d-4d1f-863a-dbb1e475e264",
"metadata": {},
"outputs": [],
"source": [
"__author__ = \"Zhengxuan Wu\"\n",
"__version__ = \"08/07/2024\""
"__version__ = \"09/23/2024\""
]
},
{
Expand All @@ -50,7 +50,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"id": "dd197c1f-71b5-4379-a9dd-2f6ff27083f6",
"metadata": {},
"outputs": [
Expand All @@ -75,7 +75,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 4,
"id": "209bfc46-7685-4e66-975f-3280ed516b52",
"metadata": {},
"outputs": [],
Expand All @@ -85,7 +85,8 @@
" SourcelessIntervention,\n",
" TrainableIntervention,\n",
" DistributedRepresentationIntervention,\n",
" CollectIntervention\n",
" CollectIntervention,\n",
" JumpReLUAutoencoderIntervention\n",
")\n",
"\n",
"from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer\n",
Expand All @@ -108,28 +109,14 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"id": "a6e7e7fb-5e73-4711-b378-bc1b04ab1e7f",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7e4aaac37998428dbe22cc95595c3fcc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors.index.json: 0%| | 0.00/24.2k [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fa51ca858082486bb23796d8146aadea",
"model_id": "192a06afdbdc4c868bc6d20677b3dd38",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -143,49 +130,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "77cf508dd4bc42e19452855fdeb8744b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00001-of-00003.safetensors: 0%| | 0.00/4.99G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2913e6370a364fe6823d35b31ccbd9e4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00002-of-00003.safetensors: 0%| | 0.00/4.98G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "18e7e07b13394f58bab2b5c4c6b32394",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00003-of-00003.safetensors: 0%| | 0.00/481M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6defff33e9d74d8b9f754d20092dd7b4",
"model_id": "3e088f2b3808489bac70b9aeb5ae73f0",
"version_major": 2,
"version_minor": 0
},
Expand All @@ -195,20 +140,6 @@
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a5981f2c5b6f4e41acd7a7ad5a574557",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"generation_config.json: 0%| | 0.00/168 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
Expand Down Expand Up @@ -276,7 +207,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 7,
"id": "d490a50c-a1cd-4def-90c2-cd6bfe67266f",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -311,6 +242,7 @@
"class JumpReLUSAECollectIntervention(\n",
" CollectIntervention\n",
"):\n",
" \"\"\"Collect activations\"\"\"\n",
" def __init__(self, **kwargs):\n",
" # Note that we initialise these to zeros because we're loading in pre-trained weights.\n",
" # If you want to train your own SAEs then we recommend using blah\n",
Expand Down Expand Up @@ -500,7 +432,7 @@
"metadata": {},
"outputs": [],
"source": [
"class JumpReLUSAEIntervention(\n",
"class JumpReLUSAESteeringIntervention(\n",
" SourcelessIntervention,\n",
" TrainableIntervention, \n",
" DistributedRepresentationIntervention\n",
Expand Down Expand Up @@ -544,7 +476,7 @@
"metadata": {},
"outputs": [],
"source": [
"sae = JumpReLUSAEIntervention(\n",
"sae = JumpReLUSAESteeringIntervention(\n",
" embed_dim=params['W_enc'].shape[0],\n",
" low_rank_dimension=params['W_enc'].shape[1]\n",
")\n",
Expand Down Expand Up @@ -614,6 +546,104 @@
"source": [
"**Here you go: a \"Space-travel, time-travel\" Doodle!**"
]
},
{
"cell_type": "markdown",
"id": "22cabe19-2c2f-46c7-a631-d0b40fca5308",
"metadata": {},
"source": [
"### Interchange intervention with JumpReLU SAEs.\n",
"\n",
"You can also swap values between examples for a specific latent dimension. However, since SAE usually maps a concpet to 1D subspace, swapping between examples and resetting the scalar to another value are similar.\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "4f23b199-ca01-4676-9a2d-61b24b96dc2f",
"metadata": {},
"outputs": [],
"source": [
"sae = JumpReLUAutoencoderIntervention(\n",
" embed_dim=params['W_enc'].shape[0],\n",
" low_rank_dimension=params['W_enc'].shape[1]\n",
")\n",
"sae.load_state_dict(pt_params, strict=False)\n",
"sae.cuda()\n",
"\n",
"# add the intervention to the model computation graph via the config\n",
"pv_model = pyvene.IntervenableModel({\n",
" \"component\": f\"model.layers[{LAYER}].output\",\n",
" \"intervention\": sae}, model=model)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "9dbe3883-3588-45fe-91bf-aeb075dea642",
"metadata": {},
"outputs": [],
"source": [
"base = tokenizer(\n",
" \"Which dog breed do people think is cuter, poodle or doodle?\", \n",
" return_tensors=\"pt\").to(\"cuda\")\n",
"source = tokenizer(\n",
" \"Origin (general) Space-travel, time-travel\", \n",
" return_tensors=\"pt\").to(\"cuda\")\n",
"\n",
"# run an interchange intervention \n",
"original_outputs, intervened_outputs = pv_model(\n",
" # the base input\n",
" base=base, \n",
" # the source input\n",
" sources=source, \n",
" # the location to intervene (swap last tokens)\n",
" unit_locations={\"sources->base\": (11, 14)},\n",
" # the SAE latent dimension mapping to the time travel concept (\"10004\")\n",
" subspaces=[10004],\n",
" output_original_output=True\n",
")\n",
"logits_diff = intervened_outputs.logits[:,-1] - original_outputs.logits[:,-1]\n",
"values, indices = logits_diff.topk(k=10, sorted=True)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "57b8c19a-c73f-47e5-b7f3-6b9353802a96",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"** topk logits diff **\n"
]
},
{
"data": {
"text/plain": [
"['PhysRevD',\n",
" ' transporting',\n",
" ' teleport',\n",
" ' space',\n",
" ' transit',\n",
" ' transported',\n",
" ' transporter',\n",
" ' transpor',\n",
" ' multiverse',\n",
" ' universes']"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(\"** topk logits diff **\")\n",
"tokenizer.batch_decode(indices[0].unsqueeze(dim=-1))"
]
}
],
"metadata": {
Expand Down

0 comments on commit 21d0c7b

Please sign in to comment.