Fix CI (#41)
* Fix CI

Ignore docstring errors in flake8

Run black on scripts folder

Format patchscope_base

Black patchscope.py

Ignore more errors

Run black on obvs

Run pre-commit hooks

Update dependencies to pass pip-audit

Fix more flake8 errors

Ignore INP001

Fix lenses

Some more fixes

Fix the last flake8 problems

Black update

Fix autoflake

Fix pylint issues for patchscope

Fix pylint in other places

Fix other pylint errors

Update links from jcoombes/obvs to obvslib/obvs

Fix token identity code

Fix prettier

Ignore attribute-defined-outside-init

Remove TODO comment

Reformat files

Update poetry.lock

More formatting

Fix isinstance

Fix long log line

* Run pre-commit
tvhong authored Apr 11, 2024
1 parent 2d4a407 commit 4c80f93
Showing 19 changed files with 379 additions and 233 deletions.
14 changes: 6 additions & 8 deletions README.md
@@ -1,18 +1,16 @@
# obvs

[![CI](https://github.com/jcoombes/obvs/actions/workflows/main.yaml/badge.svg)](https://github.com/jcoombes/obvs/actions/workflows/main.yaml)
[![CI](https://github.com/obvslib/obvs/actions/workflows/main.yaml/badge.svg)](https://github.com/obvslib/obvs/actions/workflows/main.yaml)

Making Transformers Obvious

## Project cheatsheet

- **pre-commit:** `pre-commit run --all-files`
- **pytest:** `pytest` or `pytest -s`
- **coverage:** `coverage run -m pytest` or `coverage html`
- **poetry sync:** `poetry install --no-root --sync`
- **updating requirements:** see [docs/updating_requirements.md](docs/updating_requirements.md)


- **pre-commit:** `pre-commit run --all-files`
- **pytest:** `pytest` or `pytest -s`
- **coverage:** `coverage run -m pytest` or `coverage html`
- **poetry sync:** `poetry install --no-root --sync`
- **updating requirements:** see [docs/updating_requirements.md](docs/updating_requirements.md)

## Initial project setup

2 changes: 1 addition & 1 deletion benchmarking/run_benchmarks.py
@@ -73,7 +73,7 @@ def read_results_file(results_file: str) -> dict:
results = {}

if not Path(results_file).exists():
return
return results

# read results file
with open(results_file) as csv_file:
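As an aside, a minimal caller sketch (not code from the repository) of why returning `results` instead of a bare `return` matters: the function is annotated `-> dict`, and downstream code that iterates the result would fail on `None`. The import path and file name below are assumptions.

```python
# Hypothetical import path; run_benchmarks.py lives in the benchmarking/ folder.
from benchmarking.run_benchmarks import read_results_file

# With the old bare `return`, a missing file yielded None and this loop raised
# an AttributeError; returning the empty dict keeps the `-> dict` contract.
results = read_results_file("results/missing_run.csv")
for benchmark, metrics in results.items():
    print(benchmark, metrics)
```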
12 changes: 6 additions & 6 deletions docs/release_process.md
@@ -77,11 +77,11 @@ git branch -D release/VERSION

## 2. GitHub Steps

- Copy the **raw markdown** for the release notes in CHANGELOG: [https://github.com/jcoombes/obvs/blob/main/CHANGELOG.md]
- Once you've pushed the tag, you will see it on this page: [https://github.com/jcoombes/obvs/tags]
- Edit the tag and add the release notes
- You will then see the release appear here: [https://github.com/jcoombes/obvs/releases]
- This also sends an email update to anyone on the team who has subscribed containing formatted release notes.
- Once the release is created, edit the release and assign the milestone to the release. Save changes.
- Copy the **raw markdown** for the release notes in CHANGELOG: [https://github.com/obvslib/obvs/blob/main/CHANGELOG.md]
- Once you've pushed the tag, you will see it on this page: [https://github.com/obvslib/obvs/tags]
- Edit the tag and add the release notes
- You will then see the release appear here: [https://github.com/obvslib/obvs/releases]
- This also sends an email update to anyone on the team who has subscribed containing formatted release notes.
- Once the release is created, edit the release and assign the milestone to the release. Save changes.

To finish, copy the release notes and post in any relevant Slack channel or email lists to inform members about the release.
128 changes: 69 additions & 59 deletions obvs/lenses.py
@@ -7,12 +7,10 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import List

from pathlib import Path

import torch
import numpy as np
import torch
from plotly.graph_objects import Figure

from obvs.logging import logger
@@ -112,8 +110,8 @@ def compute_surprisal(self, word: str | None = None):
logger.info(f"Computing surprisal of target tokens: {target} from word {word}")

if hasattr(self, "source_layers") and hasattr(self, "target_layers"):
for i, source_layer in enumerate(self.source_layers):
for j, target_layer in enumerate(self.target_layers):
for i, _source_layer in enumerate(self.source_layers):
for j, _target_layer in enumerate(self.target_layers):
logits = self.outputs[i * len(self.target_layers) + j]
self.surprisal[i, j] = SurprisalMetric.batch(logits, target)
self.precision_at_1[i, j] = PrecisionAtKMetric.batch(logits, target, 1)
@@ -149,8 +147,8 @@ def run_and_compute(
self.prepare_data_array()

if hasattr(self, "source_layers") and hasattr(self, "target_layers"):
for i, source_layer in enumerate(self.source_layers):
for j, target_layer in enumerate(self.target_layers):
for i, _source_layer in enumerate(self.source_layers):
for j, _target_layer in enumerate(self.target_layers):
self._nextloop(next(self.outputs), word, i, j)
elif hasattr(self, "source_layers"):
for i, output in enumerate(self.outputs):
@@ -223,14 +221,16 @@ def visualize(self, show: bool = True):
if show:
self.fig.show()
return self


class BaseLogitLens:
""" Parent class for LogitLenses.
Patchscope and classic logit-lens are run differently,
but share the same visualization.
"""
"""Parent class for LogitLenses.
Patchscope and classic logit-lens are run differently,
but share the same visualization.
"""

def __init__(self, model: str, prompt: str, device: str):
""" Constructor. Setup a Patchscope object with Source and Target context.
"""Constructor. Setup a Patchscope object with Source and Target context.
The target context is equal to the source context, apart from the layer.
Args:
@@ -254,8 +254,8 @@ def __init__(self, model: str, prompt: str, device: str):
self.patchscope = Patchscope(source_context, target_context)
self.data = {}

def visualize(self, kind: str = 'top_logits_preds', file_name: str = '') -> Figure:
""" Visualize the logit lens results in one of the following ways:
def visualize(self, kind: str = "top_logits_preds", file_name: str = "") -> Figure:
"""Visualize the logit lens results in one of the following ways:
top_logits_preds: Heatmap with the top predicted tokens and their logits
Args:
kind (str): The kind of visualization
@@ -266,13 +266,12 @@ def visualize(self, kind: str = 'top_logits_preds', file_name: str = '') -> Figure:
"""

if not self.data:
logger.error('You need to call .run() before .visualize()!')
logger.error("You need to call .run() before .visualize()!")
return Figure()

if kind == 'top_logits_preds':

if kind == "top_logits_preds":
# get the top logits and corresponding tokens for each layer and token position
top_logits, top_pred_idcs = torch.max(self.data['logits'], dim=-1)
top_logits, top_pred_idcs = torch.max(self.data["logits"], dim=-1)

# create NxM list of strings from the top predictions
top_preds = []
@@ -281,39 +280,47 @@ def visualize(self, kind: str = 'top_logits_preds', file_name: str = '') -> Figure:
for i in range(top_pred_idcs.shape[0]):
top_preds.append(self.patchscope.tokenizer.batch_decode(top_pred_idcs[i]))

x_ticks = [f'{self.patchscope.tokenizer.decode(tok)}'
for tok in self.data['substring_tokens']]
y_ticks = [f'{self.patchscope.MODEL_SOURCE}_{self.patchscope.LAYER_SOURCE}{i}'
for i in self.data['layers']]
x_ticks = [
f"{self.patchscope.tokenizer.decode(tok)}" for tok in self.data["substring_tokens"]
]
y_ticks = [
f"{self.patchscope.MODEL_SOURCE}_{self.patchscope.LAYER_SOURCE}{i}"
for i in self.data["layers"]
]

# create a heatmap with the top logits and predicted tokens

fig = create_heatmap(x_ticks, y_ticks, logits, cell_annotations=preds,
title='Top predicted token and its logit')
fig = create_heatmap(
x_ticks,
y_ticks,
top_logits,
cell_annotations=top_preds,
title="Top predicted token and its logit",
)

if file_name:
fig.write_html(f'{file_name.replace(".html", "")}.html')
return fig


class PatchscopeLogitLens(BaseLogitLens):
""" Implementation of logit-lens in patchscope framework.
The logit-lens is defined in the patchscope framework as follows:
S = T (source prompt = target prompt)
M = M* (source model = target model)
l* = L* (target layer = last layer)
i = i* (source position = target position)
f = id (mapping = identity function)
"""Implementation of logit-lens in patchscope framework.
The logit-lens is defined in the patchscope framework as follows:
S = T (source prompt = target prompt)
M = M* (source model = target model)
l* = L* (target layer = last layer)
i = i* (source position = target position)
f = id (mapping = identity function)
The source layer l and position i can vary.
The source layer l and position i can vary.
In words: The logit-lens maps the hidden state at position i of layer l of the model M
to the last layer of that same model. It is equal to taking the hidden state and
applying unembed to it.
"""
In words: The logit-lens maps the hidden state at position i of layer l of the model M
to the last layer of that same model. It is equal to taking the hidden state and
applying unembed to it.
"""

def run(self, substring: str, layers: List[int]):
""" Run the logit lens for each layer in layers and each token in substring.
def run(self, substring: str, layers: list[int]):
"""Run the logit lens for each layer in layers and each token in substring.
Args:
substring (str): Substring of the prompt for which the top prediction and logits
@@ -325,38 +332,40 @@ def run(self, substring: str, layers: List[int]):
start_pos, substring_tokens = self.patchscope.source_position_tokens(substring)

# initialize tensor for logits
self.data['logits'] = torch.zeros(len(layers), len(substring_tokens),
self.patchscope.tokenizer.vocab_size)
self.data["logits"] = torch.zeros(
len(layers),
len(substring_tokens),
self.patchscope.tokenizer.vocab_size,
)

# loop over each layer and token in substring
for i, layer in enumerate(layers):
for j in range(len(substring_tokens)):

self.patchscope.source.layer = layer
self.patchscope.source.position = start_pos + j
self.patchscope.target.position = start_pos + j
self.patchscope.run()

self.data['logits'][i, j, :] = self.patchscope.logits()[start_pos + j].to('cpu')
self.data["logits"][i, j, :] = self.patchscope.logits()[start_pos + j].to("cpu")

# empty CUDA cache to avoid filling of GPU memory
torch.cuda.empty_cache()

# detach logits, save tokens from substring and layer indices
self.data['logits'] = self.data['logits'].detach()
self.data['substring_tokens'] = substring_tokens
self.data['layers'] = layers
self.data["logits"] = self.data["logits"].detach()
self.data["substring_tokens"] = substring_tokens
self.data["layers"] = layers


class ClassicLogitLens(BaseLogitLens):
""" Implementation of LogitLens in standard fashion.
Run a forward pass on the model and apply the final layer norm and unembed to the output of
a specific layer to get the logits of that layer.
For convenience, use methods from the Patchscope class.
"""Implementation of LogitLens in standard fashion.
Run a forward pass on the model and apply the final layer norm and unembed to the output of
a specific layer to get the logits of that layer.
For convenience, use methods from the Patchscope class.
"""

def run(self, substring: str, layers: List[int]):
""" Run the logit lens for each layer in layers and each token in substring.
def run(self, substring: str, layers: list[int]):
"""Run the logit lens for each layer in layers and each token in substring.
Args:
substring (str): Substring of the prompt for which the top prediction and logits
@@ -368,15 +377,16 @@ def run(self, substring: str, layers: List[int]):
start_pos, substring_tokens = self.patchscope.source_position_tokens(substring)

# initialize tensor for logits
self.data['logits'] = torch.zeros(len(layers), len(substring_tokens),
self.patchscope.tokenizer.vocab_size)
self.data["logits"] = torch.zeros(
len(layers),
len(substring_tokens),
self.patchscope.tokenizer.vocab_size,
)

# loop over all layers
for i, layer in enumerate(layers):

# with one forward pass, we can get the logits of every position
with self.patchscope.source_model.trace(self.patchscope.source.prompt) as _:

# get the appropriate sub-module and block from source_model
sub_mod = getattr(self.patchscope.source_model, self.patchscope.MODEL_SOURCE)
block = getattr(sub_mod, self.patchscope.LAYER_SOURCE)
@@ -390,12 +400,12 @@ def run(self, substring: str, layers: List[int]):

# loop over all tokens in substring and get the corresponding logits
for j in range(len(substring_tokens)):
self.data['logits'][i, j, :] = logits[0, start_pos + j, :].to('cpu')
self.data["logits"][i, j, :] = logits[0, start_pos + j, :].to("cpu")

# empty CUDA cache to avoid filling of GPU memory
torch.cuda.empty_cache()

# detach logits, save tokens from substring and layer indices
self.data['logits'] = self.data['logits'].detach()
self.data['substring_tokens'] = substring_tokens
self.data['layers'] = layers
self.data["logits"] = self.data["logits"].detach()
self.data["substring_tokens"] = substring_tokens
self.data["layers"] = layers
6 changes: 3 additions & 3 deletions obvs/logging.py
@@ -22,11 +22,11 @@ def set_tqdm_logging(exclude_loggers=None):
# Get all existing loggers (including the root) and replace their handlers.
loggers = [logging.root] + list(logging.root.manager.loggerDict.values())

for logger in loggers:
for a_logger in loggers:
if (
isinstance(logger, logging.Logger) and logger.name not in exclude_loggers
isinstance(a_logger, logging.Logger) and a_logger.name not in exclude_loggers
): # Exclude specified loggers
logger.handlers = [tqdm_handler]
a_logger.handlers = [tqdm_handler]


# Now exclude your file logger by name when calling set_tqdm_logging
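A short, hedged usage sketch of the function shown above; the excluded logger name is a placeholder, but the call shape follows the `exclude_loggers` check in the loop.

```python
from obvs.logging import set_tqdm_logging

# Route all loggers through the tqdm-aware handler so progress bars stay intact,
# but leave the named file logger ("obvs_file" is a hypothetical name) untouched.
set_tqdm_logging(exclude_loggers={"obvs_file"})
```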
2 changes: 2 additions & 0 deletions obvs/metrics.py
@@ -18,6 +18,7 @@ def __init__(self, topk=10, dist_sync_on_step=False, batch_size=None) -> None:
self.add_state("correct", default=torch.zeros(batch_size), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

# pylint: disable=arguments-differ
def update(self, logits, true_token_index) -> None:
batch_size = logits.shape[0]
self.correct[:batch_size] += self.batch(logits, true_token_index, self.topk)
@@ -73,6 +74,7 @@ def __init__(self, dist_sync_on_step=False, batch_size=None) -> None:
)
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

# pylint: disable=arguments-differ
def update(self, logits, true_token_index) -> None:
batch_size = logits.shape[0]
self.surprisal[:batch_size] += self.batch(logits, true_token_index)
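For reference, a minimal sketch of how these metrics are exercised, assuming the torchmetrics-style interface visible above (`update(logits, true_token_index)` plus the `batch` helpers that obvs/lenses.py calls on the class). Tensor shapes, the vocabulary size, and the import path are assumptions.

```python
import torch

from obvs.metrics import PrecisionAtKMetric, SurprisalMetric

batch_size, vocab_size = 4, 50257
logits = torch.randn(batch_size, vocab_size)                    # per-example next-token logits
true_token_index = torch.randint(0, vocab_size, (batch_size,))  # per-example target token ids

precision_at_1 = PrecisionAtKMetric(topk=1, batch_size=batch_size)
surprisal = SurprisalMetric(batch_size=batch_size)

# Stateful path, as used inside a training/evaluation loop.
precision_at_1.update(logits, true_token_index)
surprisal.update(logits, true_token_index)

# Stateless helpers, called the same way lenses.py does.
p_at_1 = PrecisionAtKMetric.batch(logits, true_token_index, 1)
surp = SurprisalMetric.batch(logits, true_token_index)
```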