Feature/acces single heads #40

Merged: 13 commits, Apr 11, 2024

Changes from 10 commits

obvs/patchscope.py (92 changes: 81 additions & 11 deletions)
@@ -35,6 +35,7 @@
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field

import einops
import torch
from nnsight import LanguageModel
from tqdm import tqdm
@@ -55,6 +56,7 @@ class SourceContext:
prompt: str | torch.Tensor
position: Sequence[int] | None = None
layer: int = -1
head: int | None = None
model_name: str = "gpt2"
device: str = "cuda" if torch.cuda.is_available() else "cpu"

@@ -128,6 +130,7 @@ def from_source(
position=source.position,
model_name=source.model_name,
layer=source.layer,
head=source.head,
mapping_function=mapping_function or (lambda x: x),
max_new_tokens=max_new_tokens,
device=source.device,
@@ -168,8 +171,10 @@ def __init__(self, source: SourceContext, target: TargetContext) -> None:

self.tokenizer = self.source_model.tokenizer

self.MODEL_SOURCE, self.LAYER_SOURCE = self.get_model_specifics(self.source.model_name)
self.MODEL_TARGET, self.LAYER_TARGET = self.get_model_specifics(self.target.model_name)
self.MODEL_SOURCE, self.LAYER_SOURCE, self.ATTN_SOURCE, self.HEAD_SOURCE = \
self.get_model_specifics(self.source.model_name)
self.MODEL_TARGET, self.LAYER_TARGET, self.ATTN_TARGET, self.HEAD_TARGET = \
self.get_model_specifics(self.target.model_name)

self._target_outputs: list[torch.Tensor] = []

@@ -186,18 +191,34 @@ def source_forward_pass(self) -> None:
# TODO: validate this with non GPT2 & GPTJ models
self.source_model.transformer.wte.output = self.source.soft_prompt

self._source_hidden_state = self.manipulate_source().save()
self.source_output = self.source_model.lm_head.output[0].save()
self._source_hidden_state = self.manipulate_source().detach().save()
self.source_output = self.source_model.lm_head.output[0].detach().save()

def manipulate_source(self) -> torch.Tensor:
"""
Get the hidden state from the source representation.

NB: This is separated out from the source_forward_pass method to allow for batching.
"""
return getattr(getattr(self.source_model, self.MODEL_SOURCE), self.LAYER_SOURCE)[

# get the specified layer
layer = getattr(getattr(self.source_model, self.MODEL_SOURCE), self.LAYER_SOURCE)[
self.source.layer
].output[0][:, self._source_position, :]
]

# if a head index is given, need to access the ATTN and HEAD components
if self.source.head is not None:
attn = getattr(layer, self.ATTN_SOURCE)
# TODO may not be .input for other models
head_act = getattr(attn, self.HEAD_SOURCE).input[0][0]
Collaborator:
Why are we using input instead of output?

My understanding is that patchscope always uses output, and if a researcher needs an input from layer i, they can access the output from layer i-1.

llinauer (Collaborator, Author):
The problem is that the output of the c_attn layer in GPT2Attention is not the same as the input of c_proj.
c_attn.output gives us the Q, K & V values concatenated into one tensor. We want the attention layer outputs (sometimes referred to as z-values), which are calculated in between the c_attn and c_proj forward calls in the GPT2Attention object. So they are the input of c_proj, but not the output of c_attn.
See GPT2Attention.forward for reference (https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L306)

tvhong (Collaborator, Apr 11, 2024):
I see. So c_proj is the equivalent of W^O in the original transformer paper?

If so, I agree that the concatenated heads would be at .attn.c_proj.input.

[image: screenshot of the multi-head attention equations from the paper]

https://arxiv.org/pdf/1706.03762.pdf
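
For reference, the screenshot presumably shows the multi-head attention definition from Section 3.2.2 of that paper, which in LaTeX reads:

\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O,
\quad \text{where } \mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)

so c_proj plays the role of W^O, and its input is the concatenation of the per-head outputs.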


# need to reshape the output into the specific heads
head_act = einops.rearrange(head_act,
'batch pos (n_head d_head) -> batch pos n_head d_head',
n_head=attn.num_heads, d_head=attn.head_dim)
return head_act[:, self._source_position, self.source.head, :]

return layer.output[0][:, self._source_position, :]
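
To make the thread above concrete, here is a small standalone sketch (not part of the diff; all dimensions are illustrative GPT-2-small values) of the tensors involved: c_attn's output is Q, K and V concatenated, while the per-head z-values that c_proj receives can be split with the same einops.rearrange pattern used above.

import einops
import torch

# Illustrative GPT-2-small dimensions (not taken from the diff)
batch, pos, n_head, d_head = 1, 4, 12, 64
hidden = n_head * d_head  # 768

# c_attn.output: Q, K and V concatenated along the last dimension
qkv = torch.randn(batch, pos, 3 * hidden)
q, k, v = qkv.split(hidden, dim=-1)  # each (batch, pos, hidden)

# The z-values (per-head attention outputs, concatenated) are what c_proj
# receives as input, shape (batch, pos, hidden)
z = torch.randn(batch, pos, hidden)  # stands in for attn.c_proj.input

# Splitting into heads, mirroring the rearrange in manipulate_source
z_heads = einops.rearrange(
    z, "batch pos (n_head d_head) -> batch pos n_head d_head",
    n_head=n_head, d_head=d_head,
)
print(z_heads.shape)  # torch.Size([1, 4, 12, 64])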

def map(self) -> None:
"""
@@ -227,20 +248,69 @@ def target_forward_pass(self) -> None:
self.manipulate_target()

def manipulate_target(self) -> None:
(
getattr(getattr(self.target_model, self.MODEL_TARGET), self.LAYER_TARGET)[
self.target.layer
].output[0][:, self._target_position, :]
) = self._mapped_hidden_state

# get the specified layer
layer = getattr(getattr(self.target_model, self.MODEL_TARGET), self.LAYER_TARGET)[
self.target.layer
]

# if a head index is given, need to access the ATTN and HEAD components
if self.target.head is not None:
attn = getattr(layer, self.ATTN_TARGET)
# TODO may not be .input for other models
concat_head_act = getattr(attn, self.HEAD_TARGET).input[0][0]

# need to reshape the output of head into the specific heads
split_head_act = einops.rearrange(
concat_head_act, 'batch pos (n_head d_head) -> batch pos n_head d_head',
n_head=attn.num_heads, d_head=attn.head_dim
)

# check if the dimensions of the mapped_hidden_state and target head activations match
target_act = split_head_act[:, self._target_position, self.target.head, :]
if self._mapped_hidden_state.shape != target_act.shape:
logger.error('Cannot set activation of head %s in target model with shape'
' %s to patched activation of source model with shape %s!',
self.target.head, list(target_act.shape),
list(self._mapped_hidden_state.shape))
return
split_head_act[:, self._target_position, self.target.head, :] = self._mapped_hidden_state
else:
layer.output[0][:, self._target_position, :] = self._mapped_hidden_state

self._target_outputs.append(self.target_model.lm_head.output[0].save())
for _ in range(self.target.max_new_tokens - 1):
self._target_outputs.append(self.target_model.lm_head.next().output[0].save())

def check_patchscope_setup(self) -> bool:
""" Check if patchscope is correctly set-up before running """

# head can be int or None, patchscope run is only possible if they have the same type
# TODO: Find out how to do it PEP8 compliant
if type(self.source.head) != type(self.target.head):
logger.error('Cannot run patchscope with source head attribute: %s and target'
' head attribute: %s. Both need to be of the same type.', self.source.head,
self.target.head)
return False

# currently, accessing single head activations is only supported for GPT2LMHead models
if (self.source.head is not None and 'gpt2' not in self.source.model_name or
self.target.head is not None and 'gpt2' not in self.target.model_name):
Collaborator:
Does this work with other GPT models (e.g., GPT-J)?

llinauer (Collaborator, Author):
Unfortunately not.
GPT-J, despite being similar, uses a different attention implementation (GPTJAttention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py#L100)

llinauer (Collaborator, Author):
We need to implement a mechanism that works for a range of model architectures.

tvhong (Collaborator, Apr 11, 2024):
Yeah, good news is that we're starting to see a pattern emerge.

I'm thinking we want to have a base ModelAccessor class that looks like:

class ModelAccessor(ABC):
  def get_block_output(position: list[int], layer: int) -> Tensor:
    raise NotImplementedError(...)

  def set_block_output(position: list[int], layer: int) -> None:
    raise NotImplementedError(...)

  def get_head_attn(position: list[int], layer: int, head: list[int]) -> Tensor:
    raise NotImplementedError(...)

  def set_head_attn(position: list[int], layer: int, head: list[int]) -> None:
    raise NotImplementedError(...)

and each model can implement this class.

raise NotImplementedError('Accessing single head activations is currently only'
' implemented for GPT2-style models')

return True
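
As an illustration of the ModelAccessor idea sketched in the review thread above (nothing below exists in the PR; all class and parameter names are hypothetical), a GPT-2-style accessor could hide the head split/merge behind get/set methods. This toy version operates on pre-captured c_proj inputs rather than a live model, just to show the shape handling:

from abc import ABC, abstractmethod

import einops
import torch


class ModelAccessor(ABC):
    """Interface along the lines proposed in the review thread (hypothetical)."""

    @abstractmethod
    def get_head_attn(self, position: int, layer: int, head: int) -> torch.Tensor: ...

    @abstractmethod
    def set_head_attn(self, value: torch.Tensor, position: int, layer: int, head: int) -> None: ...


class GPT2Accessor(ModelAccessor):
    """Toy GPT-2-style accessor over pre-captured c_proj inputs.

    z_by_layer[i] holds the concatenated per-head activations of layer i,
    shape (batch, pos, n_head * d_head), i.e. the tensor the PR reaches via
    attn.c_proj.input on GPT-2.
    """

    def __init__(self, z_by_layer: dict[int, torch.Tensor], n_head: int):
        self.z_by_layer = z_by_layer
        self.n_head = n_head

    def _split(self, layer: int) -> torch.Tensor:
        # (batch, pos, n_head * d_head) -> (batch, pos, n_head, d_head)
        return einops.rearrange(
            self.z_by_layer[layer],
            "batch pos (n_head d_head) -> batch pos n_head d_head",
            n_head=self.n_head,
        )

    def get_head_attn(self, position: int, layer: int, head: int) -> torch.Tensor:
        return self._split(layer)[:, position, head, :]

    def set_head_attn(self, value: torch.Tensor, position: int, layer: int, head: int) -> None:
        split = self._split(layer)
        split[:, position, head, :] = value
        self.z_by_layer[layer] = einops.rearrange(
            split, "batch pos n_head d_head -> batch pos (n_head d_head)"
        )


# Usage with illustrative GPT-2-small dimensions
accessor = GPT2Accessor({0: torch.randn(1, 4, 12 * 64)}, n_head=12)
patched = torch.zeros(1, 64)
accessor.set_head_attn(patched, position=2, layer=0, head=5)
assert torch.equal(accessor.get_head_attn(position=2, layer=0, head=5), patched)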

def run(self) -> None:
"""
Run the patchscope
"""

# check before running
if not self.check_patchscope_setup():
logger.error('Abort running patchscope')
return

self.clear()
self.source_forward_pass()
self.map()
obvs/patchscope_base.py (6 changes: 3 additions & 3 deletions)
@@ -17,10 +17,10 @@ def get_model_specifics(self, model_name):
The following works for gpt2, llama2 and mistral models.
"""
if "gpt" in model_name:
return "transformer", "h"
return "transformer", "h", "attn", "c_proj"
if "mamba" in model_name:
return "backbone", "layers"
return "model", "layers"
return "backbone", "layers", None, None
return "model", "layers", "attention", "heads"

@abstractmethod
def source_forward_pass(self) -> None:
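
For orientation (not part of the diff): on a Hugging Face GPT-2 model the four strings returned above map onto the module tree, so the getattr chains used in patchscope.py resolve as below, assuming the standard transformers GPT-2 implementation. A minimal check, using a randomly initialised model to avoid downloading weights:

from transformers import GPT2Config, GPT2LMHeadModel

model = GPT2LMHeadModel(GPT2Config())  # random weights, no download needed
MODEL, LAYER, ATTN, HEAD = "transformer", "h", "attn", "c_proj"

block = getattr(getattr(model, MODEL), LAYER)[0]  # model.transformer.h[0], a GPT2Block
attn = getattr(block, ATTN)                       # block.attn, a GPT2Attention
c_proj = getattr(attn, HEAD)                      # attn.c_proj, the output projection (W^O)
print(type(block).__name__, type(attn).__name__, type(c_proj).__name__)
# GPT2Block GPT2Attention Conv1D
print(attn.num_heads, attn.head_dim)  # 12 64 for the default config
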
poetry.lock (2 changes: 1 addition & 1 deletion)

Some generated files are not rendered by default.

pyproject.toml (1 change: 1 addition & 0 deletions)
@@ -26,6 +26,7 @@ fsspec = "2023.9.2"
ipython = "^8.22.2"
ipdb = "^0.13.13"
torchmetrics = "^1.3.1"
einops = "^0.7.0"

[tool.poetry.dev-dependencies]
# Everything below here is alphabetically sorted