diff --git a/pyvene/__init__.py b/pyvene/__init__.py index 086e2302..2527fa4a 100644 --- a/pyvene/__init__.py +++ b/pyvene/__init__.py @@ -30,11 +30,13 @@ from .models.interventions import NoiseIntervention from .models.interventions import SigmoidMaskIntervention from .models.interventions import AutoencoderIntervention +from .models.interventions import JumpReLUAutoencoderIntervention from .models.interventions import InterventionOutput # Utils from .models.basic_utils import * +from .models.intervention_utils import _do_intervention_by_swap from .models.intervenable_modelcard import type_to_module_mapping, type_to_dimension_mapping from .models.gpt2.modelings_intervenable_gpt2 import create_gpt2 from .models.gpt2.modelings_intervenable_gpt2 import create_gpt2_lm
diff --git a/pyvene/models/interventions.py b/pyvene/models/interventions.py index a839d87a..1dcac4d4 100644 --- a/pyvene/models/interventions.py +++ b/pyvene/models/interventions.py @@ -596,3 +596,73 @@ def forward(self, base, source, subspaces=None): def __str__(self): return f"AutoencoderIntervention()"
+
+
+class JumpReLUAutoencoderIntervention(TrainableIntervention):
+    """Interchange intervention on a JumpReLU SAE's latent subspaces.
+
+    Base and source activations are encoded into the SAE latent space, the
+    requested latent subspaces are swapped from source into base, and the
+    result is decoded back into the activation space.
+    """
+
+    def __init__(self, **kwargs):
+        """Create zero-initialised SAE parameters.
+
+        ``kwargs`` is forwarded to ``TrainableIntervention`` and must contain
+        ``low_rank_dimension``, which is used as the SAE latent width.
+        """
+        # Parameters start at zero because pre-trained SAE weights are
+        # expected to be loaded over them; pick a real initialisation scheme
+        # before training an SAE from scratch with this module.
+        super().__init__(**kwargs, keep_last_dim=True)
+        latent_dim = kwargs["low_rank_dimension"]
+        self.W_enc = torch.nn.Parameter(torch.zeros(self.embed_dim, latent_dim))
+        self.W_dec = torch.nn.Parameter(torch.zeros(latent_dim, self.embed_dim))
+        self.threshold = torch.nn.Parameter(torch.zeros(latent_dim))
+        self.b_enc = torch.nn.Parameter(torch.zeros(latent_dim))
+        self.b_dec = torch.nn.Parameter(torch.zeros(self.embed_dim))
+
+    def encode(self, input_acts):
+        """Project activations into SAE latents with the JumpReLU gate."""
+        pre_acts = input_acts @ self.W_enc + self.b_enc
+        # JumpReLU: a latent passes (after ReLU) only where its
+        # pre-activation exceeds the per-latent threshold; otherwise it is 0.
+        mask = pre_acts > self.threshold
+        return mask * torch.nn.functional.relu(pre_acts)
+
+    def decode(self, acts):
+        """Map SAE latents back to the activation space."""
+        return acts @ self.W_dec + self.b_dec
+
+    def forward(self, base, source=None, subspaces=None):
+        """Swap SAE latents from ``source`` into ``base`` and decode.
+
+        ``source`` keeps its ``None`` default for signature compatibility,
+        but an interchange needs a source representation, so ``None`` is
+        rejected up front with an explicit ``TypeError``.
+        """
+        if source is None:
+            raise TypeError(
+                "JumpReLUAutoencoderIntervention.forward() requires a "
+                "`source` representation for the interchange intervention"
+            )
+        # generate latents for base and source runs.
+        base_latent = self.encode(base)
+        source_latent = self.encode(source)
+        # intervention.
+        intervened_latent = _do_intervention_by_swap(
+            base_latent,
+            source_latent,
+            "interchange",
+            self.interchange_dim,
+            subspaces,
+            subspace_partition=self.subspace_partition,
+            use_fast=self.use_fast,
+        )
+        # decode intervened latent.
+        return self.decode(intervened_latent)
+
+    def __str__(self):
+        return "JumpReLUAutoencoderIntervention()"
+
diff --git a/tutorials/basic_tutorials/Sparse_Autoencoder.ipynb b/tutorials/basic_tutorials/Sparse_Autoencoder.ipynb index 1f3bbcfa..c3bd399e 100644 --- a/tutorials/basic_tutorials/Sparse_Autoencoder.ipynb +++ b/tutorials/basic_tutorials/Sparse_Autoencoder.ipynb @@ -18,13 +18,13 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 1, "id": "e5d14f0a-b02d-4d1f-863a-dbb1e475e264", "metadata": {}, "outputs": [], "source": [ "__author__ = \"Zhengxuan Wu\"\n", - "__version__ = \"08/07/2024\"" + "__version__ = \"09/23/2024\"" ] }, { @@ -50,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, "id": "dd197c1f-71b5-4379-a9dd-2f6ff27083f6", "metadata": {}, "outputs": [ @@ -75,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 4, "id": "209bfc46-7685-4e66-975f-3280ed516b52", "metadata": {}, "outputs": [], @@ -85,7 +85,8 @@ " SourcelessIntervention,\n", " TrainableIntervention,\n", " DistributedRepresentationIntervention,\n", - " CollectIntervention\n", + " CollectIntervention,\n", + " JumpReLUAutoencoderIntervention\n", ")\n", "\n", "from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer\n", @@ -108,28 +109,14 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "id": "a6e7e7fb-5e73-4711-b378-bc1b04ab1e7f", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "7e4aaac37998428dbe22cc95595c3fcc", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "model.safetensors.index.json: 0%| | 0.00/24.2k [00:00base\": (11, 14)},\n", + " # the SAE latent dimension mapping to the time travel concept (\"10004\")\n", + " subspaces=[10004],\n", + " output_original_output=True\n", + ")\n", + "logits_diff = intervened_outputs.logits[:,-1] - original_outputs.logits[:,-1]\n", 
+ "values, indices = logits_diff.topk(k=10, sorted=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "57b8c19a-c73f-47e5-b7f3-6b9353802a96", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "** topk logits diff **\n" + ] + }, + { + "data": { + "text/plain": [ + "['PhysRevD',\n", + " ' transporting',\n", + " ' teleport',\n", + " ' space',\n", + " ' transit',\n", + " ' transported',\n", + " ' transporter',\n", + " ' transpor',\n", + " ' multiverse',\n", + " ' universes']" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(\"** topk logits diff **\")\n", + "tokenizer.batch_decode(indices[0].unsqueeze(dim=-1))" + ] } ], "metadata": {