Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move Patchscope Logit Lens Token Looping to Script #43

Merged
merged 10 commits into from
Apr 22, 2024
66 changes: 36 additions & 30 deletions obvs/lenses.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,29 +275,27 @@ def visualize(self, kind: str = "top_logits_preds", file_name: str = "") -> Figu

# create NxM list of strings from the top predictions
top_preds = []

# loop over the layer dimension in top_preds, get a list of predictions for
# each position associated with that layer
for i in range(top_pred_idcs.shape[0]):
top_preds.append(self.patchscope.tokenizer.batch_decode(top_pred_idcs[i]))

x_ticks = [
f"{self.patchscope.tokenizer.decode(tok)}" for tok in self.data["substring_tokens"]
]
y_ticks = [
f"{self.patchscope.MODEL_SOURCE}_{self.patchscope.LAYER_SOURCE}{i}"
f"{self.patchscope.source_base_name}_{self.patchscope.source_layer_name}{i}"
for i in self.data["layers"]
]

# create a heatmap with the top logits and predicted tokens

fig = create_heatmap(
x_ticks,
y_ticks,
top_logits,
cell_annotations=top_preds,
title="Top predicted token and its logit",
)

if file_name:
fig.write_html(f'{file_name.replace(".html", "")}.html')
return fig
Expand All @@ -314,47 +312,55 @@ class PatchscopeLogitLens(BaseLogitLens):

The source layer l and position i can vary.

In words: The logit-lens maps the hidden state at position i of layer l of the model M
to the last layer of that same model. It is equal to taking the hidden state and
applying unembed to it.
In words: The logit-lens maps the hidden state at position i of layer l of the model M
to the last layer of that same model. It is equal to taking the hidden state and
applying unembed to it.
"""

def run(self, substring: str, layers: list[int]):
"""Run the logit lens for each layer in layers and each token in substring.

Args:
substring (str): Substring of the prompt for which the top prediction and logits
should be calculated.
layers (List[int]): Indices of Transformer Layers for which the lens should be applied
def __init__(self, model: str, prompt: str, device: str, layers: list[int], substring: str):
"""
substring (str): Substring of the prompt for which the top prediction and logits should be calculated.
layers (list[int]): Indices of Transformer Layers for which the lens should be applied
"""

# get starting position and tokens of substring
super().__init__(model, prompt, device)
start_pos, substring_tokens = self.patchscope.source_position_tokens(substring)

# initialize tensor for logits
self.start_pos = start_pos
self.layers = layers
shaheenahmedc marked this conversation as resolved.
Show resolved Hide resolved
self.substring_tokens = substring_tokens
self.data["logits"] = torch.zeros(
len(layers),
len(substring_tokens),
self.patchscope.tokenizer.vocab_size,
)

def run(self, position: int):
llinauer marked this conversation as resolved.
Show resolved Hide resolved
"""Run the logit lens for each layer in layers, for a specific position in the prompt.

Args:
position (int): Position in the prompt for which the lens should be applied
"""
# get starting position and tokens of substring
assert position < len(self.substring_tokens), "Position out of bounds!"

# loop over each layer and token in substring
for i, layer in enumerate(layers):
for j in range(len(substring_tokens)):
self.patchscope.source.layer = layer
self.patchscope.source.position = start_pos + j
self.patchscope.target.position = start_pos + j
self.patchscope.run()
for i, layer in enumerate(self.layers):
self.patchscope.source.layer = layer
self.patchscope.source.position = self.start_pos + position
self.patchscope.target.position = self.start_pos + position
self.patchscope.run()

self.data["logits"][i, j, :] = self.patchscope.logits()[start_pos + j].to("cpu")
self.data["logits"][i, position, :] = self.patchscope.logits()[
self.start_pos + position
].to("cpu")

# empty CDUA cache to avoid filling of GPU memory
# empty CUDA cache to avoid filling of GPU memory
torch.cuda.empty_cache()

# detach logits, save tokens from substring and layer indices
self.data["logits"] = self.data["logits"].detach()
self.data["substring_tokens"] = substring_tokens
self.data["layers"] = layers
self.data["substring_tokens"] = self.substring_tokens
self.data["layers"] = self.layers


class ClassicLogitLens(BaseLogitLens):
Expand All @@ -370,7 +376,7 @@ def run(self, substring: str, layers: list[int]):
Args:
substring (str): Substring of the prompt for which the top prediction and logits
should be calculated.
layers (List[int]): Indices of Transformer Layers for which the lens should be applied
layers (list[int]): Indices of Transformer Layers for which the lens should be applied
"""

# get starting position and tokens of substring
Expand All @@ -388,8 +394,8 @@ def run(self, substring: str, layers: list[int]):
# with one forward pass, we can get the logits of every position
with self.patchscope.source_model.trace(self.patchscope.source.prompt) as _:
# get the appropriate sub-module and block from source_model
sub_mod = getattr(self.patchscope.source_model, self.patchscope.MODEL_SOURCE)
block = getattr(sub_mod, self.patchscope.LAYER_SOURCE)
sub_mod = getattr(self.patchscope.source_model, self.patchscope.source_base_name)
block = getattr(sub_mod, self.patchscope.source_layer_name)

# get hidden state after specified layer
hidden = block[layer].output[0]
Expand Down
25 changes: 13 additions & 12 deletions obvs/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,20 @@ def create_heatmap(
go.Figure: The heatmap figure.
"""

# x_data and y_data is treated categorical in plotly heatmaps, if the lists contain
# duplicates, these will be removed -> prevent this
x_categories = {val: i for i, val in enumerate(x_data)}
x_numeric = [x_categories[val] for val in x_data]
# Ensure the outer list of values matches the length of y_data
assert len(values) == len(y_data), "Length of values must match length of y_data"
for row in values:
assert len(row) == len(x_data), "Each row in values must match the length of x_data"

y_categories = {val: i for i, val in enumerate(y_data)}
y_numeric = [y_categories[val] for val in y_data]
# Use ordered indexing to accommodate non-unique labels
x_ticks = list(range(len(x_data)))
y_ticks = list(range(len(y_data)))

fig = go.Figure(
data=go.Heatmap(
z=values,
x=x_numeric,
y=y_numeric,
x=x_ticks,
y=y_ticks,
hoverongaps=False,
text=cell_annotations,
texttemplate="%{text}",
Expand All @@ -56,15 +57,15 @@ def create_heatmap(
tickfont=dict(size=16),
titlefont=dict(size=18),
tickangle=-45,
tickvals=list(x_categories.values()),
ticktext=list(x_categories.keys()),
tickvals=x_ticks,
ticktext=x_data,
),
yaxis=dict(
title=y_label,
tickfont=dict(size=16),
titlefont=dict(size=18),
tickvals=list(y_categories.values()),
ticktext=list(y_categories.keys()),
tickvals=y_ticks,
ticktext=y_data,
),
titlefont=dict(size=20),
)
Expand Down
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 9 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ version = "0.0.1"
description = "Making Transformers Obvious"
authors = ["Jamie Coombes <[email protected]>"]
license = "MIT"
packages = [
{ include = "obvs" },
]
packages = [{ include = "obvs" }]

[tool.poetry.dependencies]
python = "^3.10.0"
Expand All @@ -27,6 +25,7 @@ ipython = "^8.22.2"
ipdb = "^0.13.13"
torchmetrics = "^1.3.1"
einops = "^0.7.0"
idna = "^3.7"

[tool.poetry.dev-dependencies]
# Everything below here is alphabetically sorted
Expand Down Expand Up @@ -56,7 +55,7 @@ pytest = "^7.3.1"

[tool.poetry.group.dev.dependencies]
pyinstrument = "^4.6.2"
huggingface-hub = {extras = ["cli"], version = "^0.20.3"}
huggingface-hub = { extras = ["cli"], version = "^0.20.3" }

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down Expand Up @@ -89,25 +88,25 @@ min-similarity-lines = 150
max-statements = 89
max-args = 22
max-branches = 17
disable= [
disable = [
"fixme",
"invalid-name", # disable for now, will fix later in patchscope
"line-too-long", # already handled by black
"invalid-name", # disable for now, will fix later in patchscope
"line-too-long", # already handled by black
"locally-disabled",
"logging-fstring-interpolation",
"missing-class-docstring",
"missing-function-docstring",
"missing-module-docstring",
"no-else-return",
"no-member", # disable for now, will fix later in patchscope_base
"no-member", # disable for now, will fix later in patchscope_base
"protected-access",
"suppressed-message",
"too-few-public-methods",
"too-many-instance-attributes", # already handled by black
"too-many-instance-attributes", # already handled by black
"too-many-public-methods",
"use-dict-literal",
"attribute-defined-outside-init", # disable for now, will fix in lenses.py
]
]
# good-names = []
# disable = []
logging-format-style = "new"
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ filelock==3.13.3 ; python_full_version >= "3.10.0" and python_full_version < "4.
fsspec==2023.9.2 ; python_full_version >= "3.10.0" and python_full_version < "4.0.0"
h11==0.14.0 ; python_full_version >= "3.10.0" and python_full_version < "4.0.0"
huggingface-hub==0.20.3 ; python_full_version >= "3.10.0" and python_full_version < "4.0.0"
idna==3.6 ; python_full_version >= "3.10.0" and python_full_version < "4.0.0"
idna==3.7 ; python_full_version >= "3.10.0" and python_full_version < "4.0.0"
importlib-metadata==7.1.0 ; python_full_version >= "3.10.0" and python_full_version < "4.0.0"
ipdb==0.13.13 ; python_full_version >= "3.10.0" and python_full_version < "4.0.0"
ipykernel==6.29.4 ; python_full_version >= "3.10.0" and python_full_version < "4.0.0"
Expand Down
15 changes: 12 additions & 3 deletions scripts/reproduce_logitlens_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,18 @@
("patchscope_logit_lens", PatchscopeLogitLens),
("classic_logit_lens", ClassicLogitLens),
]:
ll = ll_class(model_name, prompt, "auto")
ll.run(substring, layers)
fig = ll.visualize()
if ll_type == "classic_logit_lens":
ll = ll_class(model_name, prompt, "auto")
ll.run(substring, layers)
fig = ll.visualize()
elif ll_type == "patchscope_logit_lens":
ll = ll_class(model_name, prompt, "auto", layers, substring)
token_ids = ll.substring_tokens
for i in range(len(token_ids)):
ll.run(i)
fig = ll.visualize()
else:
raise ValueError(f"Unknown logit lens type: {ll_type}")
fig.write_html(
f'{model_name.replace("-", "_").replace("/", "_").lower()}_{ll_type}_logits_top_preds.html',
)
Loading