From d6ab70a41bc48d7be6caced2d66f2e4f8a4d5af5 Mon Sep 17 00:00:00 2001 From: UFO-101 <47218308+UFO-101@users.noreply.github.com> Date: Tue, 13 Aug 2024 17:28:58 -0700 Subject: [PATCH 1/2] Improve attention masking (#699) * Allow attention_mask to override the default mask in HookedTransformer.forward(). * Add attention_mask argument to loss_fn() and lm_cross_entropy_loss() and adjust the cross entropy calculation to ignore masked (padding) tokens. --------- Co-authored-by: Bryce Meyer --- tests/integration/test_attention_mask.py | 32 ++++++++++++++++++++ tests/integration/test_cross_entropy_loss.py | 32 ++++++++++++++++++++ transformer_lens/HookedTransformer.py | 24 +++++++++++---- transformer_lens/utils.py | 15 ++++++++- 4 files changed, 96 insertions(+), 7 deletions(-) create mode 100644 tests/integration/test_cross_entropy_loss.py diff --git a/tests/integration/test_attention_mask.py b/tests/integration/test_attention_mask.py index 6b0951f5c..fc4fe41c5 100644 --- a/tests/integration/test_attention_mask.py +++ b/tests/integration/test_attention_mask.py @@ -45,3 +45,35 @@ def attn_hook(attn, hook): ] model.run_with_hooks(input, fwd_hooks=fwd_hooks) + + +def test_masked_tokens(): + """Test that masking tokens works as expected.""" + MODEL = "solu-1l" + prompts = [ + "Hello world!", + "The quick brown fox jumps over the lazy dog.", + ] + model = HookedTransformer.from_pretrained(MODEL) + tokens = model.to_tokens(prompts) + + # Part 1: If the mask is all ones, the output should be the same as if there was no mask. + full_mask = torch.ones_like(tokens) + no_mask_out = model(tokens) + full_mask_out = model(tokens, attention_mask=full_mask) + assert torch.allclose(no_mask_out, full_mask_out), "Full mask should be equivalent to no mask" + + # Part 2: If the mask has a column of zeros, the output should be the same as if that token + # position was removed from the input. + remove_tok_idx = 2 + edited_tokens = torch.cat([tokens[:, :remove_tok_idx], tokens[:, remove_tok_idx + 1 :]], dim=1) + edited_mask = full_mask.clone() + edited_mask[:, remove_tok_idx] = 0 + edited_no_mask_out = model(edited_tokens) + edited_mask_out = model(tokens, attention_mask=edited_mask) + edited_mask_out = torch.cat( + [edited_mask_out[:, :remove_tok_idx], edited_mask_out[:, remove_tok_idx + 1 :]], dim=1 + ) + assert torch.allclose( + edited_no_mask_out, edited_mask_out, atol=1e-4 + ), "Edited mask should be equivalent to no mask" diff --git a/tests/integration/test_cross_entropy_loss.py b/tests/integration/test_cross_entropy_loss.py new file mode 100644 index 000000000..040070031 --- /dev/null +++ b/tests/integration/test_cross_entropy_loss.py @@ -0,0 +1,32 @@ +import torch + +from transformer_lens.HookedTransformer import HookedTransformer + + +def test_cross_entropy_attention_mask(): + """Check that adding a bunch of masked tokens to the input does not change the loss.""" + MODEL = "solu-1l" + model = HookedTransformer.from_pretrained(MODEL) + + # Step 1: Get the default loss on a prompt + prompt = ["The quick brown fox jumps over the lazy dog."] + default_tokens = model.to_tokens(prompt) + default_attention_mask = torch.ones_like(default_tokens) + default_loss = model(default_tokens, return_type="loss") + ones_mask_loss = model( + default_tokens, attention_mask=default_attention_mask, return_type="loss" + ) + assert torch.allclose(default_loss, ones_mask_loss, atol=1e-6) + + # Step 2: Get the loss when we add some extra tokens to the input and set their attention mask + # to zero + extra_prompt = ["Lorem ipsum dolor sit amet, consectetur adipiscing elit."] + extra_tokens = model.to_tokens(extra_prompt) + extra_zeros_attention_mask = torch.zeros_like(extra_tokens) + + combined_tokens = torch.cat([default_tokens, extra_tokens], dim=1) + combined_attention_mask = torch.cat([default_attention_mask, extra_zeros_attention_mask], dim=1) + combined_masked_loss = model( + combined_tokens, attention_mask=combined_attention_mask, return_type="loss" + ) + assert torch.allclose(default_loss, combined_masked_loss) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 164e99803..8ee2e74f8 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -248,6 +248,7 @@ def input_to_embed( input: Union[str, List[str], Int[torch.Tensor, "batch pos"]], prepend_bos: Optional[Union[bool, None]] = USE_DEFAULT_VALUE, padding_side: Optional[Union[Literal["left", "right"], None]] = USE_DEFAULT_VALUE, + attention_mask: Optional[torch.Tensor] = None, past_kv_cache: Optional[HookedTransformerKeyValueCache] = None, ) -> Tuple[ Float[torch.Tensor, "batch pos d_model"], # residual @@ -284,7 +285,15 @@ def input_to_embed( if tokens.device.type != self.cfg.device: tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg)) - if (self.tokenizer and self.tokenizer.padding_side == "left") or past_kv_cache is not None: + if attention_mask is not None: + assert attention_mask.shape == tokens.shape, ( + f"Attention mask shape {attention_mask.shape} does not match tokens shape " + f"{tokens.shape}" + ) + attention_mask = attention_mask.to(devices.get_device_for_block_index(0, self.cfg)) + elif ( + self.tokenizer and self.tokenizer.padding_side == "left" + ) or past_kv_cache is not None: # If the padding side is left or we are using caching, we need to compute the attention # mask for the adjustment of absolute positional embeddings and attention masking so # that pad tokens are not attended. @@ -489,9 +498,10 @@ def forward( shortformer_pos_embed: Optional[Float[torch.Tensor, "batch pos d_model"]]: Positional embedding for shortformer models. Only use if start_at_layer is not None and self.cfg.positional_embedding_type == "shortformer". - attention_mask: Optional[torch.Tensor]: The attention mask for padded tokens. Only use - if start_at_layer is not None and (self.tokenizer.padding_side == "left" or - past_kv_cache is not None). + attention_mask: Optional[torch.Tensor]: Override the attention mask used to ignore + padded tokens. If start_at_layer is not None and (self.tokenizer.padding_side == + "left" or past_kv_cache is not None), this should be passed as the attention mask + is not computed automatically. Defaults to None. stop_at_layer Optional[int]: If not None, stop the forward pass at the specified layer. Exclusive - ie, stop_at_layer = 0 will only run the embedding layer, stop_at_layer = 1 will run the embedding layer and the first transformer block, etc. Supports @@ -523,6 +533,7 @@ def forward( input, prepend_bos=prepend_bos, padding_side=padding_side, + attention_mask=attention_mask, past_kv_cache=past_kv_cache, ) else: @@ -576,7 +587,7 @@ def forward( assert ( tokens is not None ), "tokens must be passed in if return_type is 'loss' or 'both'" - loss = self.loss_fn(logits, tokens, per_token=loss_per_token) + loss = self.loss_fn(logits, tokens, attention_mask, per_token=loss_per_token) if return_type == "loss": return loss elif return_type == "both": @@ -589,6 +600,7 @@ def loss_fn( self, logits: Float[torch.Tensor, "batch pos d_vocab"], tokens: Int[torch.Tensor, "batch pos"], + attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, per_token: bool = False, ): """Wrapper around `utils.lm_cross_entropy_loss`. @@ -597,7 +609,7 @@ def loss_fn( """ if tokens.device != logits.device: tokens = tokens.to(logits.device) - return utils.lm_cross_entropy_loss(logits, tokens, per_token) + return utils.lm_cross_entropy_loss(logits, tokens, attention_mask, per_token) @overload def run_with_cache( diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py index f5d3bacbe..7e5828e14 100644 --- a/transformer_lens/utils.py +++ b/transformer_lens/utils.py @@ -115,6 +115,7 @@ def to_numpy(tensor): def lm_cross_entropy_loss( logits: Float[torch.Tensor, "batch pos d_vocab"], tokens: Int[torch.Tensor, "batch pos"], + attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, per_token: bool = False, ) -> Union[Float[torch.Tensor, ""], Float[torch.Tensor, "batch pos"]]: """Cross entropy loss for the language model, gives the loss for predicting the NEXT token. @@ -122,6 +123,8 @@ def lm_cross_entropy_loss( Args: logits (torch.Tensor): Logits. Shape [batch, pos, d_vocab] tokens (torch.Tensor[int64]): Input tokens. Shape [batch, pos] + attention_mask (torch.Tensor[int64], optional): Attention mask. Shape [batch, pos]. Used to + mask out padding tokens. Defaults to None. per_token (bool, optional): Whether to return the log probs predicted for the correct token, or the loss (ie mean of the predicted log probs). Note that the returned array has shape [batch, seq-1] as we cannot predict the first token (alternately, we ignore the final logit). Defaults to False. """ log_probs = F.log_softmax(logits, dim=-1) @@ -129,10 +132,20 @@ def lm_cross_entropy_loss( # Offsets needed because we're predicting the NEXT token (this means the final logit is meaningless) # None and [..., 0] needed because the tensor used in gather must have the same rank. predicted_log_probs = log_probs[..., :-1, :].gather(dim=-1, index=tokens[..., 1:, None])[..., 0] + + if attention_mask is not None: + # Ignore token positions which are masked out or where the next token is masked out + # (generally padding tokens) + next_token_mask = torch.logical_and(attention_mask[:, :-1], attention_mask[:, 1:]) + predicted_log_probs *= next_token_mask + n_tokens = next_token_mask.sum().item() + else: + n_tokens = predicted_log_probs.numel() + if per_token: return -predicted_log_probs else: - return -predicted_log_probs.mean() + return -predicted_log_probs.sum() / n_tokens def lm_accuracy( From 135adcee13914860a68786ffeec2822a954cbdfc Mon Sep 17 00:00:00 2001 From: Min Cai <781391411@qq.com> Date: Wed, 14 Aug 2024 08:54:51 +0800 Subject: [PATCH 2/2] add a demo for Patchscopes and Generation with Patching (#692) * add a demo for Patchscopes and Generation with Patching * added pathscopes generation demo to tests * ignored a couple cells --------- Co-authored-by: Bryce Meyer --- .github/workflows/checks.yml | 1 + demos/Patchscopes_Generation_Demo.ipynb | 3776 +++++++++++++++++++++++ 2 files changed, 3777 insertions(+) create mode 100644 demos/Patchscopes_Generation_Demo.ipynb diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index ffd8b1136..1b71d373e 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -133,6 +133,7 @@ jobs: - "Main_Demo" # - "No_Position_Experiment" - "Othello_GPT" + - "Patchscopes_Generation_Demo" # - "T5" steps: - uses: actions/checkout@v3 diff --git a/demos/Patchscopes_Generation_Demo.ipynb b/demos/Patchscopes_Generation_Demo.ipynb new file mode 100644 index 000000000..49c4655d4 --- /dev/null +++ b/demos/Patchscopes_Generation_Demo.ipynb @@ -0,0 +1,3776 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Patchscopes & Generation with Patching\n", + "\n", + "This notebook contains a demo for Patchscopes (https://arxiv.org/pdf/2401.06102) and demonstrates how to generate multiple tokens with patching. Since there're also some applications in [Patchscopes](##Patchscopes-pipeline) that require generating multiple tokens with patching, I think it's suitable to put both of them in the same notebook. Additionally, generation with patching can be well-described using Patchscopes. Therefore, I simply implement it with the Patchscopes pipeline (see [here](##Generation-with-patching))." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup (Ignore)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "import os\n", + "\n", + "DEBUG_MODE = False\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "try:\n", + " import google.colab\n", + "\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + "except:\n", + " IN_COLAB = False\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + " from IPython import get_ipython\n", + "\n", + " ipython = get_ipython()\n", + " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", + " ipython.magic(\"load_ext autoreload\")\n", + " ipython.magic(\"autoreload 2\")\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", + " %pip install transformer_lens\n", + " %pip install torchtyping\n", + " # Install my janky personal plotting utils\n", + " %pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", + " # Install another version of node that makes PySvelte work way faster\n", + " %pip install circuitsvis\n", + " # Needed for PySvelte to work, v3 came out and broke things...\n", + " %pip install typeguard==2.13.3\n", + "\n", + "import torch\n", + "from typing import List, Callable, Tuple, Union\n", + "from functools import partial\n", + "from jaxtyping import Float\n", + "from transformer_lens import HookedTransformer\n", + "from transformer_lens.ActivationCache import ActivationCache\n", + "import transformer_lens.utils as utils\n", + "from transformer_lens.hook_points import (\n", + " HookPoint,\n", + ") # Hooking utilities" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Helper Funcs\n", + "\n", + "A helper function to plot logit lens" + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "metadata": {}, + "outputs": [], + "source": [ + "import plotly.graph_objects as go\n", + "import numpy as np\n", + "\n", + "# Parameters\n", + "num_layers = 5\n", + "seq_len = 10\n", + "\n", + "# Create a matrix of tokens for demonstration\n", + "tokens = np.array([[\"token_{}_{}\".format(i, j) for j in range(seq_len)] for i in range(num_layers)])[::-1]\n", + "values = np.random.rand(num_layers, seq_len)\n", + "orig_tokens = ['Token {}'.format(i) for i in range(seq_len)]\n", + "\n", + "def draw_logit_lens(num_layers, seq_len, orig_tokens, tokens, values):\n", + " # Create the heatmap\n", + " fig = go.Figure(data=go.Heatmap(\n", + " z=values,\n", + " x=orig_tokens,\n", + " y=['Layer {}'.format(i) for i in range(num_layers)][::-1],\n", + " colorscale='Blues',\n", + " showscale=True,\n", + " colorbar=dict(title='Value')\n", + " ))\n", + "\n", + " # Add text annotations\n", + " annotations = []\n", + " for i in range(num_layers):\n", + " for j in range(seq_len):\n", + " annotations.append(\n", + " dict(\n", + " x=j, y=i,\n", + " text=tokens[i, j],\n", + " showarrow=False,\n", + " font=dict(color='white')\n", + " )\n", + " )\n", + "\n", + " fig.update_layout(\n", + " annotations=annotations,\n", + " xaxis=dict(side='top'),\n", + " yaxis=dict(autorange='reversed'),\n", + " margin=dict(l=50, r=50, t=100, b=50),\n", + " width=1000,\n", + " height=600,\n", + " plot_bgcolor='white'\n", + " )\n", + "\n", + " # Show the plot\n", + " fig.show()\n", + "# draw_logit_lens(num_layers, seq_len, orig_tokens, tokens, values)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Preparation" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded pretrained model gpt2-small into HookedTransformer\n" + ] + }, + { + "data": { + "text/plain": [ + "HookedTransformer(\n", + " (embed): Embed()\n", + " (hook_embed): HookPoint()\n", + " (pos_embed): PosEmbed()\n", + " (hook_pos_embed): HookPoint()\n", + " (blocks): ModuleList(\n", + " (0-11): 12 x TransformerBlock(\n", + " (ln1): LayerNormPre(\n", + " (hook_scale): HookPoint()\n", + " (hook_normalized): HookPoint()\n", + " )\n", + " (ln2): LayerNormPre(\n", + " (hook_scale): HookPoint()\n", + " (hook_normalized): HookPoint()\n", + " )\n", + " (attn): Attention(\n", + " (hook_k): HookPoint()\n", + " (hook_q): HookPoint()\n", + " (hook_v): HookPoint()\n", + " (hook_z): HookPoint()\n", + " (hook_attn_scores): HookPoint()\n", + " (hook_pattern): HookPoint()\n", + " (hook_result): HookPoint()\n", + " )\n", + " (mlp): MLP(\n", + " (hook_pre): HookPoint()\n", + " (hook_post): HookPoint()\n", + " )\n", + " (hook_attn_in): HookPoint()\n", + " (hook_q_input): HookPoint()\n", + " (hook_k_input): HookPoint()\n", + " (hook_v_input): HookPoint()\n", + " (hook_mlp_in): HookPoint()\n", + " (hook_attn_out): HookPoint()\n", + " (hook_mlp_out): HookPoint()\n", + " (hook_resid_pre): HookPoint()\n", + " (hook_resid_mid): HookPoint()\n", + " (hook_resid_post): HookPoint()\n", + " )\n", + " )\n", + " (ln_final): LayerNormPre(\n", + " (hook_scale): HookPoint()\n", + " (hook_normalized): HookPoint()\n", + " )\n", + " (unembed): Unembed()\n", + ")" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "# I'm using an M2 macbook air, so I use CPU for better support\n", + "model = HookedTransformer.from_pretrained(\"gpt2-small\", device=\"cpu\")\n", + "model.eval()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Patchscopes Definition\n", + "\n", + "Here we first wirte down the formal definition decribed in the paper https://arxiv.org/pdf/2401.06102.\n", + "\n", + "The representations are:\n", + "\n", + "source: (S, i, M, l), where S is the source prompt, i is the source position, M is the source model, and l is the source layer.\n", + "\n", + "target: (T,i*,f,M*,l*), where T is the target prompt, i* is the target position, M* is the target model, l* is the target layer, and f is the mapping function that takes the original hidden states as input and output the target hidden states\n", + "\n", + "By defulat, S = T, i = i*, M = M*, l = l*, f = identity function" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Patchscopes Pipeline\n", + "\n", + "### Get hidden representation from the source model\n", + "\n", + "1. We first need to extract the source hidden states from model M at position i of layer l with prompt S. In TransformerLens, we can do this using run_with_cache.\n", + "2. Then, we map the source representation with a function f, and feed the hidden representation to the target position using a hook. Specifically, we focus on residual stream (resid_post), whereas you can manipulate more fine-grainedly with TransformerLens\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "prompts = [\"Patchscopes is a nice tool to inspect hidden representation of language model\"]\n", + "input_tokens = model.to_tokens(prompts)\n", + "clean_logits, clean_cache = model.run_with_cache(input_tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "def get_source_representation(prompts: List[str], layer_id: int, model: HookedTransformer, pos_id: Union[int, List[int]]=None) -> torch.Tensor:\n", + " \"\"\"Get source hidden representation represented by (S, i, M, l)\n", + " \n", + " Args:\n", + " - prompts (List[str]): a list of source prompts\n", + " - layer_id (int): the layer id of the model\n", + " - model (HookedTransformer): the source model\n", + " - pos_id (Union[int, List[int]]): the position id(s) of the model, if None, return all positions\n", + "\n", + " Returns:\n", + " - source_rep (torch.Tensor): the source hidden representation\n", + " \"\"\"\n", + " input_tokens = model.to_tokens(prompts)\n", + " _, cache = model.run_with_cache(input_tokens)\n", + " layer_name = \"blocks.{id}.hook_resid_post\"\n", + " layer_name = layer_name.format(id=layer_id)\n", + " if pos_id is None:\n", + " return cache[layer_name][:, :, :]\n", + " else:\n", + " return cache[layer_name][:, pos_id, :]" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "source_rep = get_source_representation(\n", + " prompts=[\"Patchscopes is a nice tool to inspect hidden representation of language model\"],\n", + " layer_id=2,\n", + " model=model,\n", + " pos_id=5\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Feed the representation to the target position\n", + "\n", + "First we need to map the representation using mapping function f, and then feed the target representation to the target position represented by (T,i*,f,M*,l*)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "# here we use an identity function for demonstration purposes\n", + "def identity_function(source_rep: torch.Tensor) -> torch.Tensor:\n", + " return source_rep" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "# recall the target representation (T,i*,f,M*,l*), and we also need the hidden representation from our source model (S, i, M, l)\n", + "def feed_source_representation(source_rep: torch.Tensor, prompt: List[str], f: Callable, model: HookedTransformer, layer_id: int, pos_id: Union[int, List[int]]=None) -> ActivationCache:\n", + " \"\"\"Feed the source hidden representation to the target model\n", + " \n", + " Args:\n", + " - source_rep (torch.Tensor): the source hidden representation\n", + " - prompt (List[str]): the target prompt\n", + " - f (Callable): the mapping function\n", + " - model (HookedTransformer): the target model\n", + " - layer_id (int): the layer id of the target model\n", + " - pos_id (Union[int, List[int]]): the position id(s) of the target model, if None, return all positions\n", + " \"\"\"\n", + " mapped_rep = f(source_rep)\n", + " # similar to what we did for activation patching, we need to define a function to patch the hidden representation\n", + " def resid_ablation_hook(\n", + " value: Float[torch.Tensor, \"batch pos d_resid\"],\n", + " hook: HookPoint\n", + " ) -> Float[torch.Tensor, \"batch pos d_resid\"]:\n", + " # print(f\"Shape of the value tensor: {value.shape}\")\n", + " # print(f\"Shape of the hidden representation at the target position: {value[:, pos_id, :].shape}\")\n", + " value[:, pos_id, :] = mapped_rep\n", + " return value\n", + " \n", + " input_tokens = model.to_tokens(prompt)\n", + "\n", + " logits = model.run_with_hooks(\n", + " input_tokens,\n", + " return_type=\"logits\",\n", + " fwd_hooks=[(\n", + " utils.get_act_name(\"resid_post\", layer_id),\n", + " resid_ablation_hook\n", + " )]\n", + " )\n", + " \n", + " return logits" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "patched_logits = feed_source_representation(\n", + " source_rep=source_rep,\n", + " prompt=prompts,\n", + " pos_id=3,\n", + " f=identity_function,\n", + " model=model,\n", + " layer_id=2\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([[ 3.5811, 3.5322, 2.6463, ..., -4.3504, -1.7939, 3.3541]],\n", + " grad_fn=),\n", + " tensor([[ 3.2431, 3.2708, 1.9591, ..., -4.2666, -2.2141, 3.4965]],\n", + " grad_fn=))" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "clean_logits[:, 5], patched_logits[:, 5]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generation with Patching\n", + "\n", + "In the last step, we've implemented the basic version of Patchscopes where we can only run one single forward pass. Let's now unlock the power by allowing it to generate multiple tokens!" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_with_patching(model: HookedTransformer, prompts: List[str], target_f: Callable, max_new_tokens: int = 50):\n", + " temp_prompts = prompts\n", + " input_tokens = model.to_tokens(temp_prompts)\n", + " for _ in range(max_new_tokens):\n", + " logits = target_f(\n", + " prompt=temp_prompts,\n", + " )\n", + " next_tok = torch.argmax(logits[:, -1, :])\n", + " input_tokens = torch.cat((input_tokens, next_tok.view(input_tokens.size(0), 1)), dim=1)\n", + " temp_prompts = model.to_string(input_tokens)\n", + "\n", + " return model.to_string(input_tokens)[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<|endoftext|>Patchscopes is a nice tool to inspect hidden representation of language model file bit file\n" + ] + } + ], + "source": [ + "prompts = [\"Patchscopes is a nice tool to inspect hidden representation of language model\"]\n", + "input_tokens = model.to_tokens(prompts)\n", + "target_f = partial(\n", + " feed_source_representation,\n", + " source_rep=source_rep,\n", + " pos_id=-1,\n", + " f=identity_function,\n", + " model=model,\n", + " layer_id=2\n", + ")\n", + "gen = generate_with_patching(model, prompts, target_f, max_new_tokens=3)\n", + "print(gen)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Patchscopes is a nice tool to inspect hidden representation of language model.\n", + "\n", + "It is a simple tool to inspect hidden representation of language model.\n", + "\n", + "It is a simple tool to inspect hidden representation of language model.\n", + "\n", + "It is a simple tool to inspect hidden representation of language model.\n", + "\n", + "It is\n" + ] + } + ], + "source": [ + "# Original generation\n", + "print(model.generate(prompts[0], verbose=False, max_new_tokens=50, do_sample=False))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Application Examples" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Logit Lens\n", + "\n", + "For Logit Lens, the configuration is l* ← L*. Here, L* is the last layer." + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "metadata": {}, + "outputs": [], + "source": [ + "token_list = []\n", + "value_list = []\n", + "\n", + "def identity_function(source_rep: torch.Tensor) -> torch.Tensor:\n", + " return source_rep\n", + "\n", + "for source_layer_id in range(12):\n", + " # Prepare source representation\n", + " source_rep = get_source_representation(\n", + " prompts=[\"Patchscopes is a nice tool to inspect hidden representation of language model\"],\n", + " layer_id=source_layer_id,\n", + " model=model,\n", + " pos_id=None\n", + " )\n", + "\n", + " logits = feed_source_representation(\n", + " source_rep=source_rep,\n", + " prompt=[\"Patchscopes is a nice tool to inspect hidden representation of language model\"],\n", + " f=identity_function,\n", + " model=model,\n", + " layer_id=11\n", + " )\n", + " token_list.append([model.to_string(token_id.item()) for token_id in logits.argmax(dim=-1).squeeze()])\n", + " value_list.append([value for value in torch.max(logits.softmax(dim=-1), dim=-1)[0].detach().squeeze().numpy()])" + ] + }, + { + "cell_type": "code", + "execution_count": 109, + "metadata": {}, + "outputs": [], + "source": [ + "token_list = np.array(token_list[::-1])\n", + "value_list = np.array(value_list[::-1])" + ] + }, + { + "cell_type": "code", + "execution_count": 110, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "colorbar": { + "title": { + "text": "Value" + } + }, + "colorscale": [ + [ + 0, + "rgb(247,251,255)" + ], + [ + 0.125, + "rgb(222,235,247)" + ], + [ + 0.25, + "rgb(198,219,239)" + ], + [ + 0.375, + "rgb(158,202,225)" + ], + [ + 0.5, + "rgb(107,174,214)" + ], + [ + 0.625, + "rgb(66,146,198)" + ], + [ + 0.75, + "rgb(33,113,181)" + ], + [ + 0.875, + "rgb(8,81,156)" + ], + [ + 1, + "rgb(8,48,107)" + ] + ], + "showscale": true, + "type": "heatmap", + "x": [ + "<|endoftext|>", + "Patch", + "sc", + "opes", + " is", + " a", + " nice", + " tool", + " to", + " inspect", + " hidden", + " representation", + " of", + " language", + " model" + ], + "y": [ + "Layer 11", + "Layer 10", + "Layer 9", + "Layer 8", + "Layer 7", + "Layer 6", + "Layer 5", + "Layer 4", + "Layer 3", + "Layer 2", + "Layer 1", + "Layer 0" + ], + "z": [ + [ + 0.34442219138145447, + 0.9871702790260315, + 0.3734475076198578, + 0.9830440878868103, + 0.4042338728904724, + 0.09035539627075195, + 0.8022230863571167, + 0.5206465125083923, + 0.14175501465797424, + 0.9898471236228943, + 0.9606538414955139, + 0.9691148996353149, + 0.662227988243103, + 0.9815096855163574, + 0.9055094718933105 + ], + [ + 0.08009976148605347, + 0.99101722240448, + 0.45667293667793274, + 0.40307697653770447, + 0.49327367544174194, + 0.08549172431230545, + 0.7428992390632629, + 0.8611035943031311, + 0.1983162760734558, + 0.9246276021003723, + 0.8956946730613708, + 0.8638046383857727, + 0.8365117311477661, + 0.9618501663208008, + 0.9175702333450317 + ], + [ + 0.02691030502319336, + 0.9732530117034912, + 0.19330987334251404, + 0.381843239068985, + 0.33808818459510803, + 0.07934993505477905, + 0.3974476158618927, + 0.7191767692565918, + 0.24212224781513214, + 0.7858667373657227, + 0.866357684135437, + 0.6622256636619568, + 0.8740373849868774, + 0.947133481502533, + 0.8450764417648315 + ], + [ + 0.027061497792601585, + 0.9609430432319641, + 0.2772334814071655, + 0.20079827308654785, + 0.2932577431201935, + 0.1255684345960617, + 0.32114332914352417, + 0.6489707827568054, + 0.2919656038284302, + 0.18173590302467346, + 0.635391891002655, + 0.5701303482055664, + 0.8785448670387268, + 0.8575655221939087, + 0.6919821500778198 + ], + [ + 0.026887305080890656, + 0.9309146404266357, + 0.44758421182632446, + 0.24046003818511963, + 0.28474941849708557, + 0.20104897022247314, + 0.5028793811798096, + 0.48273345828056335, + 0.2584459185600281, + 0.36538586020469666, + 0.20586784183979034, + 0.3072110712528229, + 0.9045845866203308, + 0.5042338371276855, + 0.4879302978515625 + ], + [ + 0.0265483595430851, + 0.9315882921218872, + 0.41395631432533264, + 0.2468952238559723, + 0.35624295473098755, + 0.21814416348934174, + 0.6175792813301086, + 0.7821283340454102, + 0.28484007716178894, + 0.3186572194099426, + 0.16824035346508026, + 0.5927833914756775, + 0.8808191418647766, + 0.5171196460723877, + 0.2029583901166916 + ], + [ + 0.026423994451761246, + 0.898944079875946, + 0.32038140296936035, + 0.44839850068092346, + 0.2796024978160858, + 0.20586445927619934, + 0.6313580274581909, + 0.87591552734375, + 0.18971839547157288, + 0.3038368225097656, + 0.36893585324287415, + 0.5965255498886108, + 0.7505314946174622, + 0.5989011526107788, + 0.10610682517290115 + ], + [ + 0.026437079533934593, + 0.6845366358757019, + 0.3912840485572815, + 0.37950050830841064, + 0.5224342346191406, + 0.2038283497095108, + 0.3475077748298645, + 0.647609293460846, + 0.11305152624845505, + 0.4017726182937622, + 0.4405157268047333, + 0.533568799495697, + 0.5206188559532166, + 0.2670389711856842, + 0.08740855008363724 + ], + [ + 0.026673221960663795, + 0.36045604944229126, + 0.27727553248405457, + 0.4515568017959595, + 0.5681671500205994, + 0.36901071667671204, + 0.5300043821334839, + 0.494934618473053, + 0.3656132221221924, + 0.40456005930900574, + 0.2656775712966919, + 0.2756248712539673, + 0.517121434211731, + 0.3028433322906494, + 0.09847757965326309 + ], + [ + 0.026949577033519745, + 0.3112040162086487, + 0.22643150389194489, + 0.7095355987548828, + 0.5966493487358093, + 0.4613777995109558, + 0.8436885476112366, + 0.4194002151489258, + 0.22365105152130127, + 0.4558623731136322, + 0.32150164246559143, + 0.4018287658691406, + 0.8275868892669678, + 0.3780366778373718, + 0.19973652064800262 + ], + [ + 0.027445374056696892, + 0.3283821940422058, + 0.5192154049873352, + 0.1790430098772049, + 0.6429017782211304, + 0.3577035665512085, + 0.6037949919700623, + 0.5884966254234314, + 0.18566730618476868, + 0.3142710030078888, + 0.15301460027694702, + 0.3585647940635681, + 0.4576294720172882, + 0.1486930102109909, + 0.13506801426410675 + ], + [ + 0.062298569828271866, + 0.24093002080917358, + 0.16585318744182587, + 0.16210544109344482, + 0.449150949716568, + 0.042253680527210236, + 0.11057071387767792, + 0.3447357416152954, + 0.08157400786876678, + 0.13642098009586334, + 0.07241284847259521, + 0.25115686655044556, + 0.084745854139328, + 0.0951341837644577, + 0.1267273873090744 + ] + ] + } + ], + "layout": { + "annotations": [ + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "\n", + "x": 0, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "rawl", + "x": 2, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "opes", + "x": 3, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " not", + "x": 4, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " new", + "x": 5, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " nice", + "x": 6, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "tips", + "x": 7, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " get", + "x": 8, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " inspect", + "x": 9, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " hidden", + "x": 10, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " representation", + "x": 11, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " model", + "x": 14, + "y": 0 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "\n", + "x": 0, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "urry", + "x": 2, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Operator", + "x": 3, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " not", + "x": 4, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " new", + "x": 5, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " nice", + "x": 6, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "tips", + "x": 7, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " get", + "x": 8, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " inspect", + "x": 9, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " hidden", + "x": 10, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " representation", + "x": 11, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " model", + "x": 14, + "y": 1 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "rawl", + "x": 2, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Operator", + "x": 3, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " not", + "x": 4, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " new", + "x": 5, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " nice", + "x": 6, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "tips", + "x": 7, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " keep", + "x": 8, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " inspect", + "x": 9, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " hidden", + "x": 10, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " representation", + "x": 11, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " model", + "x": 14, + "y": 2 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "atch", + "x": 2, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ":", + "x": 3, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " currently", + "x": 4, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " very", + "x": 5, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " thing", + "x": 6, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "tips", + "x": 7, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " keep", + "x": 8, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 9, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " hidden", + "x": 10, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " representation", + "x": 11, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " model", + "x": 14, + "y": 3 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "atch", + "x": 2, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ":", + "x": 3, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " currently", + "x": 4, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " unique", + "x": 5, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " little", + "x": 6, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 7, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " keep", + "x": 8, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " your", + "x": 9, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " hidden", + "x": 10, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " model", + "x": 14, + "y": 4 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "reens", + "x": 2, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ":", + "x": 3, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " currently", + "x": 4, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " unique", + "x": 5, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " little", + "x": 6, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 7, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " keep", + "x": 8, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " your", + "x": 9, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " gem", + "x": 10, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " model", + "x": 14, + "y": 5 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "ree", + "x": 2, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ":", + "x": 3, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " also", + "x": 4, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " unique", + "x": 5, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " little", + "x": 6, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 7, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " keep", + "x": 8, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 9, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " gems", + "x": 10, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " (", + "x": 14, + "y": 6 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Patch", + "x": 1, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "ream", + "x": 2, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ":", + "x": 3, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " currently", + "x": 4, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " powerful", + "x": 5, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " little", + "x": 6, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 7, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " keep", + "x": 8, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " your", + "x": 9, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " gems", + "x": 10, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " language", + "x": 13, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " (", + "x": 14, + "y": 7 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "work", + "x": 1, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "ream", + "x": 2, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "¶", + "x": 3, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " currently", + "x": 4, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " powerful", + "x": 5, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 6, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "kit", + "x": 7, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " help", + "x": 8, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " your", + "x": 9, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 10, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 12, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 13, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " data", + "x": 14, + "y": 8 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "work", + "x": 1, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "ream", + "x": 2, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "¶", + "x": 3, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " a", + "x": 4, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 5, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 6, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 7, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " help", + "x": 8, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " your", + "x": 9, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 10, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 12, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " strings", + "x": 13, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " variables", + "x": 14, + "y": 9 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ",", + "x": 0, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Notes", + "x": 1, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "rew", + "x": 2, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "¶", + "x": 3, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " a", + "x": 4, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 5, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 6, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " for", + "x": 7, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " help", + "x": 8, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " your", + "x": 9, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " items", + "x": 10, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 12, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 13, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " objects", + "x": 14, + "y": 10 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "\n", + "x": 0, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " Notes", + "x": 1, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "rew", + "x": 2, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": "\n", + "x": 3, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " a", + "x": 4, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 5, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " tool", + "x": 6, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " for", + "x": 7, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " help", + "x": 8, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 9, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " files", + "x": 10, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " of", + "x": 11, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " the", + "x": 12, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": " features", + "x": 13, + "y": 11 + }, + { + "font": { + "color": "white" + }, + "showarrow": false, + "text": ".", + "x": 14, + "y": 11 + } + ], + "height": 600, + "margin": { + "b": 50, + "l": 50, + "r": 50, + "t": 100 + }, + "plot_bgcolor": "white", + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "width": 1000, + "xaxis": { + "side": "top" + }, + "yaxis": { + "autorange": "reversed" + } + } + } + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "num_layers = 12\n", + "seq_len = len(token_list[0])\n", + "orig_tokens = [model.to_string(token_id) for token_id in model.to_tokens([\"Patchscopes is a nice tool to inspect hidden representation of language model\"])[0]]\n", + "draw_logit_lens(num_layers, seq_len, orig_tokens, token_list, value_list)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Entity Description\n", + "\n", + "Entity description tries to answer \"how LLMs resolve entity mentions across multiple layers. Concretely, given a subject entity name, such as “the summer Olympics of 1996”, how does the model contextualize the input tokens of the entity and at which layer is it fully resolved?\"\n", + "\n", + "The configuration is l* ← l, i* ← m, and it requires generating multiple tokens. Here m refers to the last position (the position of x)" + ] + }, + { + "cell_type": "code", + "execution_count": 111, + "metadata": {}, + "outputs": [], + "source": [ + " # Prepare source representation\n", + "source_rep = get_source_representation(\n", + " prompts=[\"Diana, Princess of Wales\"],\n", + " layer_id=11,\n", + " model=model,\n", + " pos_id=-1\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 115, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generation by patching layer 0:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The \"The \"The \"The \"The \"The \"The \"The \"The\n", + "==============================\n", + "\n", + "Generation by patching layer 1:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The \"The \"The \"The \"The \"The \"The \"The \"The\n", + "==============================\n", + "\n", + "Generation by patching layer 2:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The\n", + "The\n", + "\n", + "\n", + "The\n", + "The\n", + "The\n", + "\n", + "\n", + "The\n", + "The\n", + "==============================\n", + "\n", + "Generation by patching layer 3:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The\n", + "\n", + "\n", + "The\n", + "\n", + "\n", + "The\n", + "\n", + "\n", + "The\n", + "\n", + "\n", + "The\n", + "==============================\n", + "\n", + "Generation by patching layer 4:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The United States\n", + "\n", + "\n", + "The United States\n", + "\n", + "\n", + "The United States\n", + "\n", + "\n", + "==============================\n", + "\n", + "Generation by patching layer 5:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The United States\n", + "\n", + "\n", + "The United States\n", + "\n", + "\n", + "The United States\n", + "\n", + "\n", + "==============================\n", + "\n", + "Generation by patching layer 6:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The United States is the world's most popular and the world's most beautiful.\n", + "\n", + "==============================\n", + "\n", + "Generation by patching layer 7:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The United States is the world's most popular and most beautiful country.\n", + "\n", + "\n", + "\n", + "==============================\n", + "\n", + "Generation by patching layer 8:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The United States is the world's largest exporter of the world's most expensive and\n", + "==============================\n", + "\n", + "Generation by patching layer 9:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The first time I saw the film, I was in the middle of a meeting with\n", + "==============================\n", + "\n", + "Generation by patching layer 10:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\n", + "\n", + "\n", + "The world's most famous actor, actor and producer, Leonardo DiCaprio, has\n", + "==============================\n", + "\n", + "Generation by patching layer 11:\n", + "<|endoftext|>Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x, and the world's largest consumer electronics company, Samsung Electronics Co., Ltd.\n", + "\n", + "\n", + "The\n", + "==============================\n", + "\n" + ] + } + ], + "source": [ + "target_prompt = [\"Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x\"]\n", + "# need to calcualte an absolute position, instead of a relative position\n", + "last_pos_id = len(model.to_tokens(target_prompt)[0]) - 1\n", + "# we need to define the function that takes the generation as input\n", + "for target_layer_id in range(12):\n", + " target_f = partial(\n", + " feed_source_representation,\n", + " source_rep=source_rep,\n", + " pos_id=last_pos_id,\n", + " f=identity_function,\n", + " model=model,\n", + " layer_id=target_layer_id\n", + " )\n", + " gen = generate_with_patching(model, target_prompt, target_f, max_new_tokens=20)\n", + " print(f\"Generation by patching layer {target_layer_id}:\\n{gen}\\n{'='*30}\\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As we can see, maybe the early layers of gpt2-small are doing something related to entity resolution, whereas the late layers are apparently not(?)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Zero-Shot Feature Extraction\n", + "\n", + "Zero-shot Feature Extraction \"Consider factual and com- monsense knowledge represented as triplets (σ,ρ,ω) of a subject (e.g., “United States”), a relation (e.g., “largest city of”), and an object (e.g.,\n", + "“New York City”). We investigate to what extent the object ω can be extracted from the last token representation of the subject σ in an arbitrary input context.\"\n", + "\n", + "The configuration is l∗ ← j′ ∈ [1,...,L∗], i∗ ← m, T ← relation verbalization followed by x" + ] + }, + { + "cell_type": "code", + "execution_count": 359, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Co-founder of company Apple, Steve Jobs, has said that Apple\\'s iPhone 6 and 6 Plus are \"the most important phones'" + ] + }, + "execution_count": 359, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# for a triplet (company Apple, co-founder of, Steve Jobs), we need to first make sure that the object is in the continuation\n", + "source_prompt = \"Co-founder of company Apple\"\n", + "model.generate(source_prompt, verbose=False, max_new_tokens=20, do_sample=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 366, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<|endoftext|>Co-founder of x, co-founder of Google x, co-founder of Facebook x, co-founder of Twitter\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, and co-founder\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to hide\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x has been accused of being a \"fraud\" by the US government.\n", + "\n", + "\n", + "The former\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of xApple, who has been working on the iPhone since 2011, has been working on the iPhone since 2011\n", + "<|endoftext|>Co-founder of x, co-founder of Google x, co-founder of Facebook x, co-founder of Twitter\n", + "<|endoftext|>Co-founder of x, co-founder of Google, co-founder of Facebook, co-founder of Twitter, co\n", + "<|endoftext|>Co-founder of x, co-founder of x, co-founder of x, co-founder of x, co\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes mobile apps for the iPhone, iPad and iPod touch, says he's been\n", + "<|endoftext|>Co-founder of x, a company that makes a lot of things, has been accused of sexual harassment by a former employee\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of xApple, who has been working on the iPhone since 2011, has been working on the iPhone since 2011\n", + "<|endoftext|>Co-founder of x, co-founder of Google x, co-founder of Facebook x, co-founder of Twitter\n", + "<|endoftext|>Co-founder of x, co-founder of Google x, co-founder of Facebook x, co-founder of Twitter\n", + "<|endoftext|>Co-founder of x, co-founder of Google, co-founder of Facebook, co-founder of Twitter, co\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes software for the web, has been accused of using a \"secret\" code\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of x Apple, Steve Jobs, has been accused of being a \"fraud\" by a former employee who\n", + "<|endoftext|>Co-founder of x, co-founder of Google x, co-founder of Facebook x, co-founder of Twitter\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, are the first people\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, and co-founder\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes software for the iPhone, has been arrested in the US.\n", + "\n", + "\n", + "\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of xInsurance, a company that provides insurance for people with disabilities, said he's been in touch with\n", + "<|endoftext|>Co-founder of x, CEO Tim Cook, and co-founder of Facebook, Mark Zuckerberg, have been named to the\n", + "<|endoftext|>Co-founder of x, co-founder of the company Apple, and co-founder of the company Apple, and co\n", + "<|endoftext|>Co-founder of x, who has been a vocal critic of the company's recent decision to cut its workforce, has been\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to hide\n", + "<|endoftext|>Co-founder of x, a company that makes software for the iPhone, has been arrested in the US.\n", + "\n", + "\n", + "\n", + "<|endoftext|>Co-founder of x, a company that makes software for the web, has been arrested in the US.\n", + "\n", + "\n", + "\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of xInsurance, a company that provides insurance for people with disabilities, said he's been in touch with\n", + "<|endoftext|>Co-founder of x, CEO Tim Cook, and co-founder of Facebook x, Mark Zuckerberg, are among the most\n", + "<|endoftext|>Co-founder of x, who has been a vocal critic of the company's iPhone 6 and iPhone 6 Plus, has been\n", + "<|endoftext|>Co-founder of x, who has been a vocal critic of the company's recent decision to cut its workforce, has been\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes software for the iPhone, has been arrested in the US.\n", + "\n", + "\n", + "\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of x Watch, a company that helps people with disabilities, says he's been working on a new app for\n", + "<|endoftext|>Co-founder of x, CEO Tim Cook, and co-founder Steve Jobs.\n", + "\n", + "\n", + "The company's new CEO\n", + "<|endoftext|>Co-founder of x, who has been a vocal critic of the company's recent decision to cut its workforce by half,\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, and co-founder\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been named the world's most valuable person by Forbes.\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes software for the iPhone, has been arrested in the US.\n", + "\n", + "\n", + "\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of x Watch, a company that helps people with disabilities, says he's been working on a new app for\n", + "<|endoftext|>Co-founder of x, the company's new iPhone, is expected to be unveiled in the coming weeks.\n", + "\n", + "\n", + "\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, and co-founder\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, and co-founder\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been named the world's most valuable person by Forbes.\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been arrested in the US for allegedly selling a fake iPhone\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x's, who has been working on the project for over a year, has been working on the project\n", + "<|endoftext|>Co-founder of xInsurance, a company that provides insurance for people with disabilities, said he's been in touch with\n", + "<|endoftext|>Co-founder of x, Apple x, Apple x, Apple x, Apple x, Apple x, Apple x, Apple\n", + "<|endoftext|>Co-founder of x, who is now the CEO of Apple, has been named the new CEO of the company.\n", + "\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company's new product, the\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been named the world's most valuable person by Forbes.\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"bot\" to make a\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the iPhone, has been arrested in the US.\n", + "\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of xbaum.com, who has been a member of the XBAY community for over a decade,\n", + "<|endoftext|>Co-founder of x, Apple x, Apple x, Apple x, Apple x, Apple x, Apple x, Apple\n", + "<|endoftext|>Co-founder of x, who is now the CEO of Apple, has been named the new CEO of the company.\n", + "\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, are both on the\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been named the world's most valuable person by Forbes.\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"bot\" to make a\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"bot\" to make a\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been arrested in the US for allegedly selling a fake iPhone\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x Inc. and co-founder of X.com, Mark Zuckerberg, has been accused of using his\n", + "<|endoftext|>Co-founder of x, Apple x, and Apple x.\n", + "Apple x, Apple x, and Apple x.\n", + "\n", + "<|endoftext|>Co-founder of x, a guest Jan 25th, 2016 1,929 Never a guest1,929Never\n", + "<|endoftext|>Co-founder of x, co-founder of the company, and co-founder of the company, are both on the\n", + "<|endoftext|>Co-founder of x, a company that makes the iPhone, has been named the new CEO of Apple.\n", + "\n", + "\n", + "\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"bot\" to make a\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"bot\" to make a\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been accused of using a \"fake\" name to sell\n", + "<|endoftext|>Co-founder of x, a company that makes smart phones, has been arrested in the US for allegedly stealing $1 million\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n", + "<|endoftext|>Co-founder of x Watch, a company that helps people with disabilities, says he's been working on a new app for\n", + "<|endoftext|>Co-founder of x, the iPhone, the iPhone, the iPhone, the iPhone, the iPhone, the iPhone, the\n", + "<|endoftext|>Co-founder of x, the new, the new, the new, the new, the new, the new, the\n", + "<|endoftext|>Co-founder of x, the company, the company, the company, the company, the company, the company, the\n", + "<|endoftext|>Co-founder of x, the company, the company, the company, the company, the company, the company, the\n", + "<|endoftext|>Co-founder of x, the company has announced that the company has released the company has released the company has released the company\n", + "<|endoftext|>Co-founder of x, the company, the company, the company, the company, the company, the company, the\n", + "<|endoftext|>Co-founder of x, the company, said the company is now offering a new product, but the company has now announced\n", + "<|endoftext|>Co-founder of x, the company that has been working on the iPhone, said the company has been working on the iPhone\n", + "<|endoftext|>Co-founder of x, the company that created the iPhone, said the company is now working on a new product, but\n", + "<|endoftext|>Co-founder of x, a company that makes the iPhone, said that the company is working on a new product that will\n", + "<|endoftext|>Co-founder of x, a company that makes the world's most popular mobile phone, has been arrested in the US.\n", + "<|endoftext|>Co-founder of x, a startup that helps people build apps for the web, has been arrested for allegedly stealing $1\n" + ] + } + ], + "source": [ + "# Still need an aboslute position\n", + "last_pos_id = len(model.to_tokens([\"Co-founder of x\"])[0]) - 1\n", + "target_prompt = [\"Co-founder of x\"]\n", + "\n", + "# Check all the combinations, you'll see that the model is able to generate \"Steve Jobs\" in several continuations\n", + "for source_layer_id in range(12):\n", + " # Prepare source representation, here we can use relative position\n", + " source_rep = get_source_representation(\n", + " prompts=[\"Co-founder of company Apple\"],\n", + " layer_id=source_layer_id,\n", + " model=model,\n", + " pos_id=-1\n", + " )\n", + " for target_layer_id in range(12):\n", + " target_f = partial(\n", + " feed_source_representation,\n", + " source_rep=source_rep,\n", + " prompt=target_prompt,\n", + " f=identity_function,\n", + " model=model,\n", + " pos_id=last_pos_id,\n", + " layer_id=target_layer_id\n", + " )\n", + " gen = generate_with_patching(model, target_prompt, target_f, max_new_tokens=20)\n", + " print(gen)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mechinterp", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}