diff --git a/README.md b/README.md index 3f274a5..990dcb8 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,16 @@ # obvs -[![CI](https://github.com/jcoombes/obvs/actions/workflows/main.yaml/badge.svg)](https://github.com/jcoombes/obvs/actions/workflows/main.yaml) +[![CI](https://github.com/obvslib/obvs/actions/workflows/main.yaml/badge.svg)](https://github.com/obvslib/obvs/actions/workflows/main.yaml) Making Transformers Obvious ## Project cheatsheet - - **pre-commit:** `pre-commit run --all-files` - - **pytest:** `pytest` or `pytest -s` - - **coverage:** `coverage run -m pytest` or `coverage html` - - **poetry sync:** `poetry install --no-root --sync` - - **updating requirements:** see [docs/updating_requirements.md](docs/updating_requirements.md) - - +- **pre-commit:** `pre-commit run --all-files` +- **pytest:** `pytest` or `pytest -s` +- **coverage:** `coverage run -m pytest` or `coverage html` +- **poetry sync:** `poetry install --no-root --sync` +- **updating requirements:** see [docs/updating_requirements.md](docs/updating_requirements.md) ## Initial project setup diff --git a/benchmarking/run_benchmarks.py b/benchmarking/run_benchmarks.py index ba4e0ee..a52ad9f 100644 --- a/benchmarking/run_benchmarks.py +++ b/benchmarking/run_benchmarks.py @@ -73,7 +73,7 @@ def read_results_file(results_file: str) -> dict: results = {} if not Path(results_file).exists(): - return + return results # read results file with open(results_file) as csv_file: diff --git a/docs/release_process.md b/docs/release_process.md index b45490b..5842f56 100644 --- a/docs/release_process.md +++ b/docs/release_process.md @@ -77,11 +77,11 @@ git branch -D release/VERSION ## 2. GitHub Steps -- Copy the **raw markdown** for the release notes in CHANGELOG: [https://github.com/jcoombes/obvs/blob/main/CHANGELOG.md] -- Once you've pushed the tag, you will see it on this page: [https://github.com/jcoombes/obvs/tags] -- Edit the tag and add the release notes -- You will then see the release appear here: [https://github.com/jcoombes/obvs/releases] -- This also sends an email update to anyone on the team who has subscribed containing formatted release notes. -- Once the release is created, edit the release and assign the milestone to the release. Save changes. +- Copy the **raw markdown** for the release notes in CHANGELOG: [https://github.com/obvslib/obvs/blob/main/CHANGELOG.md] +- Once you've pushed the tag, you will see it on this page: [https://github.com/obvslib/obvs/tags] +- Edit the tag and add the release notes +- You will then see the release appear here: [https://github.com/obvslib/obvs/releases] +- This also sends an email update to anyone on the team who has subscribed containing formatted release notes. +- Once the release is created, edit the release and assign the milestone to the release. Save changes. To finish, copy the release notes and post in any relevant Slack channel or email lists to inform members about the release. 
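The read_results_file change above returns the empty results dict instead of None when the results file does not exist yet, so callers always get a dict back and can index it without a None check. Below is a minimal sketch of that pattern, assuming a simple two-column CSV of (benchmark, value) rows; the actual parsing in benchmarking/run_benchmarks.py may differ.

import csv
from pathlib import Path


def read_results_file(results_file: str) -> dict:
    """Return benchmark results keyed by name; an empty dict if the file is missing."""
    results: dict[str, str] = {}
    if not Path(results_file).exists():
        # Returning the empty dict (not None) keeps the return type consistent for callers.
        return results

    # Assumed layout: each row is "name,value"; rows without two columns are skipped.
    with open(results_file) as csv_file:
        for row in csv.reader(csv_file):
            if len(row) >= 2:
                results[row[0]] = row[1]
    return results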
diff --git a/obvs/lenses.py b/obvs/lenses.py index bcba308..671f587 100644 --- a/obvs/lenses.py +++ b/obvs/lenses.py @@ -7,12 +7,10 @@ from __future__ import annotations from collections.abc import Sequence -from typing import List - from pathlib import Path -import torch import numpy as np +import torch from plotly.graph_objects import Figure from obvs.logging import logger @@ -112,8 +110,8 @@ def compute_surprisal(self, word: str | None = None): logger.info(f"Computing surprisal of target tokens: {target} from word {word}") if hasattr(self, "source_layers") and hasattr(self, "target_layers"): - for i, source_layer in enumerate(self.source_layers): - for j, target_layer in enumerate(self.target_layers): + for i, _source_layer in enumerate(self.source_layers): + for j, _target_layer in enumerate(self.target_layers): logits = self.outputs[i * len(self.target_layers) + j] self.surprisal[i, j] = SurprisalMetric.batch(logits, target) self.precision_at_1[i, j] = PrecisionAtKMetric.batch(logits, target, 1) @@ -149,8 +147,8 @@ def run_and_compute( self.prepare_data_array() if hasattr(self, "source_layers") and hasattr(self, "target_layers"): - for i, source_layer in enumerate(self.source_layers): - for j, target_layer in enumerate(self.target_layers): + for i, _source_layer in enumerate(self.source_layers): + for j, _target_layer in enumerate(self.target_layers): self._nextloop(next(self.outputs), word, i, j) elif hasattr(self, "source_layers"): for i, output in enumerate(self.outputs): @@ -223,14 +221,16 @@ def visualize(self, show: bool = True): if show: self.fig.show() return self + + class BaseLogitLens: - """ Parent class for LogitLenses. - Patchscope and classic logit-lens are run differently, - but share the same visualization. - """ + """Parent class for LogitLenses. + Patchscope and classic logit-lens are run differently, + but share the same visualization. + """ def __init__(self, model: str, prompt: str, device: str): - """ Constructor. Setup a Patchscope object with Source and Target context. + """Constructor. Setup a Patchscope object with Source and Target context. The target context is equal to the source context, apart from the layer. 
Args: @@ -254,8 +254,8 @@ def __init__(self, model: str, prompt: str, device: str): self.patchscope = Patchscope(source_context, target_context) self.data = {} - def visualize(self, kind: str = 'top_logits_preds', file_name: str = '') -> Figure: - """ Visualize the logit lens results in one of the following ways: + def visualize(self, kind: str = "top_logits_preds", file_name: str = "") -> Figure: + """Visualize the logit lens results in one of the following ways: top_logits_preds: Heatmap with the top predicted tokens and their logits Args: kind (str): The kind of visualization @@ -266,13 +266,12 @@ def visualize(self, kind: str = 'top_logits_preds', file_name: str = '') -> Figu """ if not self.data: - logger.error('You need to call .run() before .visualize()!') + logger.error("You need to call .run() before .visualize()!") return Figure() - if kind == 'top_logits_preds': - + if kind == "top_logits_preds": # get the top logits and corresponding tokens for each layer and token position - top_logits, top_pred_idcs = torch.max(self.data['logits'], dim=-1) + top_logits, top_pred_idcs = torch.max(self.data["logits"], dim=-1) # create NxM list of strings from the top predictions top_preds = [] @@ -281,15 +280,23 @@ def visualize(self, kind: str = 'top_logits_preds', file_name: str = '') -> Figu for i in range(top_pred_idcs.shape[0]): top_preds.append(self.patchscope.tokenizer.batch_decode(top_pred_idcs[i])) - x_ticks = [f'{self.patchscope.tokenizer.decode(tok)}' - for tok in self.data['substring_tokens']] - y_ticks = [f'{self.patchscope.MODEL_SOURCE}_{self.patchscope.LAYER_SOURCE}{i}' - for i in self.data['layers']] + x_ticks = [ + f"{self.patchscope.tokenizer.decode(tok)}" for tok in self.data["substring_tokens"] + ] + y_ticks = [ + f"{self.patchscope.MODEL_SOURCE}_{self.patchscope.LAYER_SOURCE}{i}" + for i in self.data["layers"] + ] # create a heatmap with the top logits and predicted tokens - fig = create_heatmap(x_ticks, y_ticks, logits, cell_annotations=preds, - title='Top predicted token and its logit') + fig = create_heatmap( + x_ticks, + y_ticks, + top_logits, + cell_annotations=top_preds, + title="Top predicted token and its logit", + ) if file_name: fig.write_html(f'{file_name.replace(".html", "")}.html') @@ -297,23 +304,23 @@ def visualize(self, kind: str = 'top_logits_preds', file_name: str = '') -> Figu class PatchscopeLogitLens(BaseLogitLens): - """ Implementation of logit-lens in patchscope framework. - The logit-lens is defined in the patchscope framework as follows: - S = T (source prompt = target prompt) - M = M* (source model = target model) - l* = L* (target layer = last layer) - i = i* (source position = target position) - f = id (mapping = identity function) + """Implementation of logit-lens in patchscope framework. + The logit-lens is defined in the patchscope framework as follows: + S = T (source prompt = target prompt) + M = M* (source model = target model) + l* = L* (target layer = last layer) + i = i* (source position = target position) + f = id (mapping = identity function) - The source layer l and position i can vary. + The source layer l and position i can vary. - In words: The logit-lens maps the hidden state at position i of layer l of the model M - to the last layer of that same model. It is equal to taking the hidden state and - applying unembed to it. - """ + In words: The logit-lens maps the hidden state at position i of layer l of the model M + to the last layer of that same model. It is equal to taking the hidden state and + applying unembed to it. 
+ """ - def run(self, substring: str, layers: List[int]): - """ Run the logit lens for each layer in layers and each token in substring. + def run(self, substring: str, layers: list[int]): + """Run the logit lens for each layer in layers and each token in substring. Args: substring (str): Substring of the prompt for which the top prediction and logits @@ -325,38 +332,40 @@ def run(self, substring: str, layers: List[int]): start_pos, substring_tokens = self.patchscope.source_position_tokens(substring) # initialize tensor for logits - self.data['logits'] = torch.zeros(len(layers), len(substring_tokens), - self.patchscope.tokenizer.vocab_size) + self.data["logits"] = torch.zeros( + len(layers), + len(substring_tokens), + self.patchscope.tokenizer.vocab_size, + ) # loop over each layer and token in substring for i, layer in enumerate(layers): for j in range(len(substring_tokens)): - self.patchscope.source.layer = layer self.patchscope.source.position = start_pos + j self.patchscope.target.position = start_pos + j self.patchscope.run() - self.data['logits'][i, j, :] = self.patchscope.logits()[start_pos + j].to('cpu') + self.data["logits"][i, j, :] = self.patchscope.logits()[start_pos + j].to("cpu") # empty CDUA cache to avoid filling of GPU memory torch.cuda.empty_cache() # detach logits, save tokens from substring and layer indices - self.data['logits'] = self.data['logits'].detach() - self.data['substring_tokens'] = substring_tokens - self.data['layers'] = layers + self.data["logits"] = self.data["logits"].detach() + self.data["substring_tokens"] = substring_tokens + self.data["layers"] = layers class ClassicLogitLens(BaseLogitLens): - """ Implementation of LogitLens in standard fashion. - Run a forward pass on the model and apply the final layer norm and unembed to the output of - a specific layer to get the logits of that layer. - For convenience, use methods from the Patchscope class. + """Implementation of LogitLens in standard fashion. + Run a forward pass on the model and apply the final layer norm and unembed to the output of + a specific layer to get the logits of that layer. + For convenience, use methods from the Patchscope class. """ - def run(self, substring: str, layers: List[int]): - """ Run the logit lens for each layer in layers and each token in substring. + def run(self, substring: str, layers: list[int]): + """Run the logit lens for each layer in layers and each token in substring. 
Args: substring (str): Substring of the prompt for which the top prediction and logits @@ -368,15 +377,16 @@ def run(self, substring: str, layers: List[int]): start_pos, substring_tokens = self.patchscope.source_position_tokens(substring) # initialize tensor for logits - self.data['logits'] = torch.zeros(len(layers), len(substring_tokens), - self.patchscope.tokenizer.vocab_size) + self.data["logits"] = torch.zeros( + len(layers), + len(substring_tokens), + self.patchscope.tokenizer.vocab_size, + ) # loop over all layers for i, layer in enumerate(layers): - # with one forward pass, we can get the logits of every position with self.patchscope.source_model.trace(self.patchscope.source.prompt) as _: - # get the appropriate sub-module and block from source_model sub_mod = getattr(self.patchscope.source_model, self.patchscope.MODEL_SOURCE) block = getattr(sub_mod, self.patchscope.LAYER_SOURCE) @@ -390,12 +400,12 @@ def run(self, substring: str, layers: List[int]): # loop over all tokens in substring and get the corresponding logits for j in range(len(substring_tokens)): - self.data['logits'][i, j, :] = logits[0, start_pos + j, :].to('cpu') + self.data["logits"][i, j, :] = logits[0, start_pos + j, :].to("cpu") # empty CDUA cache to avoid filling of GPU memory torch.cuda.empty_cache() # detach logits, save tokens from substring and layer indices - self.data['logits'] = self.data['logits'].detach() - self.data['substring_tokens'] = substring_tokens - self.data['layers'] = layers + self.data["logits"] = self.data["logits"].detach() + self.data["substring_tokens"] = substring_tokens + self.data["layers"] = layers diff --git a/obvs/logging.py b/obvs/logging.py index 9baea3f..33eb220 100644 --- a/obvs/logging.py +++ b/obvs/logging.py @@ -22,11 +22,11 @@ def set_tqdm_logging(exclude_loggers=None): # Get all existing loggers (including the root) and replace their handlers. 
loggers = [logging.root] + list(logging.root.manager.loggerDict.values()) - for logger in loggers: + for a_logger in loggers: if ( - isinstance(logger, logging.Logger) and logger.name not in exclude_loggers + isinstance(a_logger, logging.Logger) and a_logger.name not in exclude_loggers ): # Exclude specified loggers - logger.handlers = [tqdm_handler] + a_logger.handlers = [tqdm_handler] # Now exclude your file logger by name when calling set_tqdm_logging diff --git a/obvs/metrics.py b/obvs/metrics.py index b232852..02c05cf 100644 --- a/obvs/metrics.py +++ b/obvs/metrics.py @@ -18,6 +18,7 @@ def __init__(self, topk=10, dist_sync_on_step=False, batch_size=None) -> None: self.add_state("correct", default=torch.zeros(batch_size), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + # pylint: disable=arguments-differ def update(self, logits, true_token_index) -> None: batch_size = logits.shape[0] self.correct[:batch_size] += self.batch(logits, true_token_index, self.topk) @@ -73,6 +74,7 @@ def __init__(self, dist_sync_on_step=False, batch_size=None) -> None: ) self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + # pylint: disable=arguments-differ def update(self, logits, true_token_index) -> None: batch_size = logits.shape[0] self.surprisal[:batch_size] += self.batch(logits, true_token_index) diff --git a/obvs/patchscope.py b/obvs/patchscope.py index 1cb76ae..780ac4f 100644 --- a/obvs/patchscope.py +++ b/obvs/patchscope.py @@ -49,6 +49,7 @@ class SourceContext: """ Source context for the patchscope """ + _prompt: str | torch.Tensor = field(init=False, repr=False, default="<|endoftext|>") _text_prompt: str = field(init=False, repr=False) _soft_prompt: torch.Tensor | None = field(init=False, repr=False) @@ -84,7 +85,9 @@ def prompt(self, value: str | torch.Tensor | None): value = torch.unsqueeze(value, 0) if value.dim() != 3: - raise ValueError(f"Soft prompt must have shape [tokens_len, d_model] or [batch, tokens_len, d_model]. But prompt.shape is {value.shape}") + raise ValueError( + f"Soft prompt must have shape [tokens_len, d_model] or [batch, tokens_len, d_model]. But prompt.shape is {value.shape}", + ) self._text_prompt = " ".join("_" * value.shape[1]) self._soft_prompt = value @@ -148,6 +151,7 @@ class ModelLoader: @staticmethod def load(model_name: str, device: str) -> LanguageModel: if "mamba" in model_name: + # pylint: disable=import-outside-toplevel # We import here because MambaInterp depends on some GPU libs that might not be installed. 
from nnsight.models.Mamba import MambaInterp @@ -178,10 +182,23 @@ def __init__(self, source: SourceContext, target: TargetContext) -> None: self.tokenizer = self.source_model.tokenizer - self.source_base_name, self.source_layer_name, self.source_attn_name, self.source_head_name = \ - self.get_model_specifics(self.source.model_name) - self.target_base_name, self.target_layer_name, self.target_attn_name, self.target_head_name = \ - self.get_model_specifics(self.target.model_name) + ( + self.source_base_name, + self.source_layer_name, + self.source_attn_name, + self.source_head_name, + ) = self.get_model_specifics(self.source.model_name) + ( + self.target_base_name, + self.target_layer_name, + self.target_attn_name, + self.target_head_name, + ) = self.get_model_specifics(self.target.model_name) + + self._source_hidden_state = None + self.source_output = None + + self._mapped_hidden_state = None self._target_outputs: list[torch.Tensor] = [] @@ -222,9 +239,12 @@ def manipulate_source(self) -> torch.Tensor: head_act = getattr(attn, self.source_head_name).input[0][0] # need to reshape the output into the specific heads - head_act = einops.rearrange(head_act, - 'batch pos (n_head d_head) -> batch pos n_head d_head', - n_head=attn.num_heads, d_head=attn.head_dim) + head_act = einops.rearrange( + head_act, + "batch pos (n_head d_head) -> batch pos n_head d_head", + n_head=attn.num_heads, + d_head=attn.head_dim, + ) return head_act[:, self._source_position, self.source.head, :] return layer.output[0][:, self._source_position, :] @@ -257,7 +277,6 @@ def target_forward_pass(self) -> None: self.manipulate_target() def manipulate_target(self) -> None: - # get the specified layer layer = getattr(getattr(self.target_model, self.target_base_name), self.target_layer_name)[ self.target.layer @@ -271,18 +290,26 @@ def manipulate_target(self) -> None: # need to reshape the output of head into the specific heads split_head_act = einops.rearrange( - concat_head_act, 'batch pos (n_head d_head) -> batch pos n_head d_head', - n_head=attn.num_heads, d_head=attn.head_dim + concat_head_act, + "batch pos (n_head d_head) -> batch pos n_head d_head", + n_head=attn.num_heads, + d_head=attn.head_dim, ) # check if the dimensions of the mapped_hidden_state and target head activations match target_act = split_head_act[:, self._target_position, self.target.head, :] if self._mapped_hidden_state.shape != target_act.shape: raise ValueError( - f'Cannot set activation of head {self.target.head} in target model with shape' - f' {list(target_act.shape)} to patched activation of source model with shape' - f' {list(self._mapped_hidden_state.shape)}!') - split_head_act[:, self._target_position, self.target.head, :] = self._mapped_hidden_state + f"Cannot set activation of head {self.target.head} in target model with shape" + f" {list(target_act.shape)} to patched activation of source model with shape" + f" {list(self._mapped_hidden_state.shape)}!", + ) + split_head_act[ + :, + self._target_position, + self.target.head, + :, + ] = self._mapped_hidden_state else: layer.output[0][:, self._target_position, :] = self._mapped_hidden_state @@ -291,21 +318,26 @@ def manipulate_target(self) -> None: self._target_outputs.append(self.target_model.lm_head.next().output[0].save()) def check_patchscope_setup(self) -> bool: - """ Check if patchscope is correctly set-up before running """ + """Check if patchscope is correctly set-up before running""" # head can be int or None, patchscope run is only possible if they have the same type - # TODO: Find out 
how to do it PEP8 compliant - if type(self.source.head) != type(self.target.head): - logger.error('Cannot run patchscope with source head attribute: %s and target' - ' head attribute: %s. Both need to be of the same type.', self.source.head, - self.target.head) + if not isinstance(self.source.head, type(self.target.head)): + logger.error( + f"Cannot run patchscope with source head attribute: {self.source.head} and target head attribute: {self.target.head}. Both need to be of the same type.", + ) return False # currently, accessing single head activations is only supported for GPT2LMHead models - if (self.source.head is not None and 'gpt2' not in self.source.model_name or - self.target.head is not None and 'gpt2' not in self.target.model_name): - raise NotImplementedError('Accessing single head activations is currently only' - ' implemented for GPT2-style models') + if ( + self.source.head is not None + and "gpt2" not in self.source.model_name + or self.target.head is not None + and "gpt2" not in self.target.model_name + ): + raise NotImplementedError( + "Accessing single head activations is currently only" + " implemented for GPT2-style models", + ) return True @@ -316,7 +348,7 @@ def run(self) -> None: # check before running if not self.check_patchscope_setup(): - raise ValueError('Cannot run patchscope with the provided arguments') + raise ValueError("Cannot run patchscope with the provided arguments") self.clear() self.source_forward_pass() diff --git a/obvs/patchscope_base.py b/obvs/patchscope_base.py index 4ab7513..5ca8b4c 100644 --- a/obvs/patchscope_base.py +++ b/obvs/patchscope_base.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Sequence +from collections.abc import Sequence import torch @@ -40,15 +40,19 @@ def run(self) -> None: @property def _source_position(self) -> Sequence[int]: - return (self.source.position - if self.source.position is not None - else range(len(self.source_token_ids))) + return ( + self.source.position + if self.source.position is not None + else range(len(self.source_token_ids)) + ) @property def _target_position(self) -> Sequence[int]: - return (self.target.position - if self.target.position is not None - else range(len(self.target_token_ids))) + return ( + self.target.position + if self.target.position is not None + else range(len(self.target_token_ids)) + ) @property def source_token_ids(self) -> list[int]: @@ -78,20 +82,20 @@ def target_tokens(self) -> list[str]: """ return [self.tokenizer.decode(token) for token in self.target_token_ids] - def top_k_tokens(self, k: int=10) -> list[str]: + def top_k_tokens(self, k: int = 10) -> list[str]: """ Return the top k tokens from the target model """ token_ids = self._target_outputs[0].value[self.target.position, :].topk(k).indices.tolist() return [self.tokenizer.decode(token_id) for token_id in token_ids] - def top_k_logits(self, k: int=10) -> list[int]: + def top_k_logits(self, k: int = 10) -> list[int]: """ Return the top k logits from the target model """ return self._target_outputs[0].value[self.target.position, :].topk(k).values.tolist() - def top_k_probs(self, k: int=10) -> list[float]: + def top_k_probs(self, k: int = 10) -> list[float]: """ Return the top k probabilities from the target model """ @@ -133,7 +137,8 @@ def llama_output(self) -> list[str]: def full_output_tokens(self) -> list[str]: """ Return the generated output from the target model - This is a bit hacky. Its not super well supported. 
I have to concatenate all the inputs and add the input tokens to them. + This is a bit hacky. Its not super well supported. I have to concatenate + all the inputs and add the input tokens to them. """ token_ids = self._output_token_ids() @@ -154,14 +159,16 @@ def find_in_source(self, substring: str) -> int: """ Find the position of the substring tokens in the source prompt - Note: only works if substring's tokenization happens to match that of the source prompt's tokenization + Note: only works if substring's tokenization happens to match that of + the source prompt's tokenization """ position, _ = self.source_position_tokens(substring) return position def source_position_tokens(self, substring: str) -> tuple[int, list[int]]: """ - Find the position of a substring in the source prompt, and return the substring tokenized + Find the position of a substring in the source prompt, and return the + substring tokenized NB: The try: except block handles the difference between gpt2 and llama tokenization. Perhaps this can be better dealt with a seperate tokenizer @@ -170,7 +177,9 @@ class that handles the differences between the tokenizers. There are a the best out of your model. """ if substring not in self.source.text_prompt: - raise ValueError(f"Substring {substring} could not be found in {self.source.text_prompt}") + raise ValueError( + f"Substring {substring} could not be found in {self.source.text_prompt}", + ) try: token_ids = self.tokenizer.encode(substring, add_special_tokens=False) @@ -183,14 +192,16 @@ def find_in_target(self, substring: str) -> int: """ Find the position of the substring tokens in the target prompt - Note: only works if substring's tokenization happens to match that of the target prompt's tokenization + Note: only works if substring's tokenization happens to match that of + the target prompt's tokenization """ position, _ = self.target_position_tokens(substring) return position def target_position_tokens(self, substring) -> tuple[int, list[int]]: """ - Find the position of a substring in the target prompt, and return the substring tokenized + Find the position of a substring in the target prompt, and return the + substring tokenized NB: The try: except block handles the difference between gpt2 and llama tokenization. Perhaps this can be better dealt with a seperate tokenizer @@ -199,7 +210,9 @@ class that handles the differences between the tokenizers. There are a the best out of your model. """ if substring not in self.target.text_prompt: - raise ValueError(f"Substring {substring} could not be found in {self.target.text_prompt}") + raise ValueError( + f"Substring {substring} could not be found in {self.target.text_prompt}", + ) try: token_ids = self.tokenizer.encode(substring, add_special_tokens=False) @@ -214,11 +227,15 @@ def n_layers(self) -> int: @property def n_layers_source(self) -> int: - return len(getattr(getattr(self.source_model, self.MODEL_TARGET), self.LAYER_TARGET)) + return len( + getattr(getattr(self.source_model, self.target_base_name), self.target_layer_name), + ) @property def n_layers_target(self) -> int: - return len(getattr(getattr(self.target_model, self.MODEL_TARGET), self.LAYER_TARGET)) + return len( + getattr(getattr(self.target_model, self.target_base_name), self.target_layer_name), + ) def compute_precision_at_1(self, estimated_probs: torch.Tensor, true_token_index): """ @@ -230,7 +247,8 @@ def compute_precision_at_1(self, estimated_probs: torch.Tensor, true_token_index Returns: - precision_at_1: Precision@1 metric result. 
- This is the evaluation method of the token identity from patchscopes: https://arxiv.org/abs/2401.06102 + This is the evaluation method of the token identity from patchscopes: + https://arxiv.org/abs/2401.06102 Its used for running an evaluation over large datasets. """ predicted_token_index = torch.argmax(estimated_probs) @@ -240,10 +258,14 @@ def compute_precision_at_1(self, estimated_probs: torch.Tensor, true_token_index def compute_surprisal(self, estimated_probs: torch.Tensor, true_token_index): """ Compute Surprisal metric. From the outputs of the target (patched) model - (estimated_probs) against the output of the source model, aka the 'true' token. + (estimated_probs) against the output of the source model, aka the 'true' + token. + Args: - - estimated_probs: The estimated probabilities for each token as a torch.Tensor. + - estimated_probs: The estimated probabilities for each token as a + torch.Tensor. - true_token_index: The index of the true token in the vocabulary. + Returns: - surprisal: Surprisal metric result. """ diff --git a/obvs/vis.py b/obvs/vis.py index f9f38c2..39b93a4 100644 --- a/obvs/vis.py +++ b/obvs/vis.py @@ -1,12 +1,17 @@ from __future__ import annotations import plotly.graph_objects as go -from typing import List -def create_heatmap(x_data: List[str | int | float], y_data: List[str | int | float], - values: List[float], title: str = '', cell_annotations: List[str] = None, - x_label: str = '', y_label: str = '') -> go.Figure: +def create_heatmap( + x_data: list[str | int | float], + y_data: list[str | int | float], + values: list[float], + title: str = "", + cell_annotations: list[str] = None, + x_label: str = "", + y_label: str = "", +) -> go.Figure: """ Create a heatmap with annotated cells. Set the x_ticks, y_ticks and title accordingly. 
@@ -31,23 +36,37 @@ def create_heatmap(x_data: List[str | int | float], y_data: List[str | int | flo y_categories = {val: i for i, val in enumerate(y_data)} y_numeric = [y_categories[val] for val in y_data] - fig = go.Figure(data=go.Heatmap( - z=values, - x=x_numeric, - y=y_numeric, - hoverongaps=False, - text=cell_annotations, - texttemplate="%{text}", - textfont={"size": 20}, - colorscale='Viridis')) + fig = go.Figure( + data=go.Heatmap( + z=values, + x=x_numeric, + y=y_numeric, + hoverongaps=False, + text=cell_annotations, + texttemplate="%{text}", + textfont={"size": 20}, + colorscale="Viridis", + ), + ) fig.update_layout( title=title, - xaxis=dict(title=x_label, tickfont=dict(size=16), titlefont=dict(size=18), tickangle=-45, - tickvals=list(x_categories.values()), ticktext=list(x_categories.keys())), - yaxis=dict(title=y_label, tickfont=dict(size=16), titlefont=dict(size=18), - tickvals=list(y_categories.values()), ticktext=list(y_categories.keys())), - titlefont=dict(size=20) + xaxis=dict( + title=x_label, + tickfont=dict(size=16), + titlefont=dict(size=18), + tickangle=-45, + tickvals=list(x_categories.values()), + ticktext=list(x_categories.keys()), + ), + yaxis=dict( + title=y_label, + tickfont=dict(size=16), + titlefont=dict(size=18), + tickvals=list(y_categories.values()), + ticktext=list(y_categories.keys()), + ), + titlefont=dict(size=20), ) return fig @@ -111,5 +130,3 @@ def plot_surprisal(layers, values, std=None, title="Surprisal") -> go.Figure: ) return fig - - diff --git a/poetry.lock b/poetry.lock index abddec8..ec156fb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -156,33 +156,33 @@ files = [ [[package]] name = "black" -version = "23.12.1" +version = "24.3.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" files = [ - {file = "black-23.12.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0aaf6041986767a5e0ce663c7a2f0e9eaf21e6ff87a5f95cbf3675bfd4c41d2"}, - {file = "black-23.12.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c88b3711d12905b74206227109272673edce0cb29f27e1385f33b0163c414bba"}, - {file = "black-23.12.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a920b569dc6b3472513ba6ddea21f440d4b4c699494d2e972a1753cdc25df7b0"}, - {file = "black-23.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:3fa4be75ef2a6b96ea8d92b1587dd8cb3a35c7e3d51f0738ced0781c3aa3a5a3"}, - {file = "black-23.12.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8d4df77958a622f9b5a4c96edb4b8c0034f8434032ab11077ec6c56ae9f384ba"}, - {file = "black-23.12.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:602cfb1196dc692424c70b6507593a2b29aac0547c1be9a1d1365f0d964c353b"}, - {file = "black-23.12.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c4352800f14be5b4864016882cdba10755bd50805c95f728011bcb47a4afd59"}, - {file = "black-23.12.1-cp311-cp311-win_amd64.whl", hash = "sha256:0808494f2b2df923ffc5723ed3c7b096bd76341f6213989759287611e9837d50"}, - {file = "black-23.12.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:25e57fd232a6d6ff3f4478a6fd0580838e47c93c83eaf1ccc92d4faf27112c4e"}, - {file = "black-23.12.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2d9e13db441c509a3763a7a3d9a49ccc1b4e974a47be4e08ade2a228876500ec"}, - {file = "black-23.12.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1bd9c210f8b109b1762ec9fd36592fdd528485aadb3f5849b2740ef17e674e"}, - {file = "black-23.12.1-cp312-cp312-win_amd64.whl", hash = 
"sha256:ae76c22bde5cbb6bfd211ec343ded2163bba7883c7bc77f6b756a1049436fbb9"}, - {file = "black-23.12.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1fa88a0f74e50e4487477bc0bb900c6781dbddfdfa32691e780bf854c3b4a47f"}, - {file = "black-23.12.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a4d6a9668e45ad99d2f8ec70d5c8c04ef4f32f648ef39048d010b0689832ec6d"}, - {file = "black-23.12.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b18fb2ae6c4bb63eebe5be6bd869ba2f14fd0259bda7d18a46b764d8fb86298a"}, - {file = "black-23.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:c04b6d9d20e9c13f43eee8ea87d44156b8505ca8a3c878773f68b4e4812a421e"}, - {file = "black-23.12.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3e1b38b3135fd4c025c28c55ddfc236b05af657828a8a6abe5deec419a0b7055"}, - {file = "black-23.12.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4f0031eaa7b921db76decd73636ef3a12c942ed367d8c3841a0739412b260a54"}, - {file = "black-23.12.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97e56155c6b737854e60a9ab1c598ff2533d57e7506d97af5481141671abf3ea"}, - {file = "black-23.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:dd15245c8b68fe2b6bd0f32c1556509d11bb33aec9b5d0866dd8e2ed3dba09c2"}, - {file = "black-23.12.1-py3-none-any.whl", hash = "sha256:78baad24af0f033958cad29731e27363183e140962595def56423e626f4bee3e"}, - {file = "black-23.12.1.tar.gz", hash = "sha256:4ce3ef14ebe8d9509188014d96af1c456a910d5b5cbf434a09fef7e024b3d0d5"}, + {file = "black-24.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7d5e026f8da0322b5662fa7a8e752b3fa2dac1c1cbc213c3d7ff9bdd0ab12395"}, + {file = "black-24.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9f50ea1132e2189d8dff0115ab75b65590a3e97de1e143795adb4ce317934995"}, + {file = "black-24.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2af80566f43c85f5797365077fb64a393861a3730bd110971ab7a0c94e873e7"}, + {file = "black-24.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:4be5bb28e090456adfc1255e03967fb67ca846a03be7aadf6249096100ee32d0"}, + {file = "black-24.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4f1373a7808a8f135b774039f61d59e4be7eb56b2513d3d2f02a8b9365b8a8a9"}, + {file = "black-24.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:aadf7a02d947936ee418777e0247ea114f78aff0d0959461057cae8a04f20597"}, + {file = "black-24.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c02e4ea2ae09d16314d30912a58ada9a5c4fdfedf9512d23326128ac08ac3d"}, + {file = "black-24.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:bf21b7b230718a5f08bd32d5e4f1db7fc8788345c8aea1d155fc17852b3410f5"}, + {file = "black-24.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:2818cf72dfd5d289e48f37ccfa08b460bf469e67fb7c4abb07edc2e9f16fb63f"}, + {file = "black-24.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4acf672def7eb1725f41f38bf6bf425c8237248bb0804faa3965c036f7672d11"}, + {file = "black-24.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7ed6668cbbfcd231fa0dc1b137d3e40c04c7f786e626b405c62bcd5db5857e4"}, + {file = "black-24.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:56f52cfbd3dabe2798d76dbdd299faa046a901041faf2cf33288bc4e6dae57b5"}, + {file = "black-24.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:79dcf34b33e38ed1b17434693763301d7ccbd1c5860674a8f871bd15139e7837"}, + {file = "black-24.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e19cb1c6365fd6dc38a6eae2dcb691d7d83935c10215aef8e6c38edee3f77abd"}, + {file = 
"black-24.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65b76c275e4c1c5ce6e9870911384bff5ca31ab63d19c76811cb1fb162678213"}, + {file = "black-24.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:b5991d523eee14756f3c8d5df5231550ae8993e2286b8014e2fdea7156ed0959"}, + {file = "black-24.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c45f8dff244b3c431b36e3224b6be4a127c6aca780853574c00faf99258041eb"}, + {file = "black-24.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6905238a754ceb7788a73f02b45637d820b2f5478b20fec82ea865e4f5d4d9f7"}, + {file = "black-24.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7de8d330763c66663661a1ffd432274a2f92f07feeddd89ffd085b5744f85e7"}, + {file = "black-24.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:7bb041dca0d784697af4646d3b62ba4a6b028276ae878e53f6b4f74ddd6db99f"}, + {file = "black-24.3.0-py3-none-any.whl", hash = "sha256:41622020d7120e01d377f74249e677039d20e6344ff5851de8a10f11f513bf93"}, + {file = "black-24.3.0.tar.gz", hash = "sha256:a0c9c4a0771afc6919578cec71ce82a3e31e054904e7197deacbc9382671c41f"}, ] [package.dependencies] @@ -3957,4 +3957,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10.0" -content-hash = "23d1ec00a721b99ee0b723159611685dc551d282838a75c6c44849a89b538bc6" +content-hash = "1804b1177e8e43c26f134dc41b187ee885f94f08a43d044a3744a26121be175d" diff --git a/pyproject.toml b/pyproject.toml index 5eb866c..2987c31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ einops = "^0.7.0" [tool.poetry.dev-dependencies] # Everything below here is alphabetically sorted bandit = "^1.7.5" -black = "^23.3.0" +black = "^24.3.0" detect-secrets = "1.2.0" flake8 = "5.0.4" flake8-bugbear = "^23.3.12" @@ -61,7 +61,6 @@ huggingface-hub = {extras = ["cli"], version = "^0.20.3"} [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" - ############ # ✅ Linters ############ @@ -90,6 +89,25 @@ min-similarity-lines = 150 max-statements = 89 max-args = 22 max-branches = 17 +disable= [ + "fixme", + "invalid-name", # disable for now, will fix later in patchscope + "line-too-long", # already handled by black + "locally-disabled", + "logging-fstring-interpolation", + "missing-class-docstring", + "missing-function-docstring", + "missing-module-docstring", + "no-else-return", + "no-member", # disable for now, will fix later in patchscope_base + "protected-access", + "suppressed-message", + "too-few-public-methods", + "too-many-instance-attributes", # already handled by black + "too-many-public-methods", + "use-dict-literal", + "attribute-defined-outside-init", # disable for now, will fix in lenses.py + ] # good-names = [] # disable = [] logging-format-style = "new" diff --git a/scripts/activation_patching_ioi.py b/scripts/activation_patching_ioi.py index 9e515c7..60fef99 100644 --- a/scripts/activation_patching_ioi.py +++ b/scripts/activation_patching_ioi.py @@ -16,27 +16,30 @@ """ -from obvs.patchscope import SourceContext, TargetContext, Patchscope +from __future__ import annotations + +from obvs.patchscope import Patchscope, SourceContext, TargetContext from obvs.vis import create_heatmap # define metric def ioi_metric(patched_logits, clean_logits, corrupt_logits, correct_idx, incorrect_idx): - """ Metric for checking correctness of indirect object identification - the normalized_logit_diff is constructed so that it is 1, if the output logits are the - same as in the 
clean run, and 0 if the output logits are the same as in the corrupted run - it increases, if the patched activation contribute in making the output logits - more like in the clean run + """Metric for checking correctness of indirect object identification + the normalized_logit_diff is constructed so that it is 1, if the output logits are the + same as in the clean run, and 0 if the output logits are the same as in the corrupted run + it increases, if the patched activation contribute in making the output logits + more like in the clean run """ # get the difference in logits of the correct token and the incorrect token at the last # position - patched_logit_diff = (patched_logits[-1, correct_idx] - patched_logits[-1, incorrect_idx]) - clean_logit_diff = (clean_logits[-1, correct_idx] - clean_logits[-1, incorrect_idx]) - corrupt_logit_diff = (corrupt_logits[-1, correct_idx] - corrupt_logits[-1, incorrect_idx]) + patched_logit_diff = patched_logits[-1, correct_idx] - patched_logits[-1, incorrect_idx] + clean_logit_diff = clean_logits[-1, correct_idx] - clean_logits[-1, incorrect_idx] + corrupt_logit_diff = corrupt_logits[-1, correct_idx] - corrupt_logits[-1, incorrect_idx] return (patched_logit_diff - corrupt_logit_diff) / (clean_logit_diff - corrupt_logit_diff) + # setup # the clean prompt produces our baseline answer, the corrupted prompt will be patched with # activations from the clean prompt run later on @@ -44,8 +47,8 @@ def ioi_metric(patched_logits, clean_logits, corrupt_logits, correct_idx, incorr corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to" # setup patchscope -source_context = SourceContext(prompt=clean_prompt, model_name='gpt2') -target_context = TargetContext(prompt=corrupted_prompt, model_name='gpt2', max_new_tokens=1) +source_context = SourceContext(prompt=clean_prompt, model_name="gpt2") +target_context = TargetContext(prompt=corrupted_prompt, model_name="gpt2", max_new_tokens=1) patchscope = Patchscope(source_context, target_context) @@ -72,12 +75,10 @@ def ioi_metric(patched_logits, clean_logits, corrupt_logits, correct_idx, incorr # loop over all layers of interest for layer in range(n_layers): - layer_metrics = [] # loop over all token positions for pos in range(len(patchscope.source_tokens)): - # set the layer and position for patching patchscope.source.layer = layer patchscope.target.layer = layer @@ -89,17 +90,28 @@ def ioi_metric(patched_logits, clean_logits, corrupt_logits, correct_idx, incorr # get the patched logits and calculate the logit difference patched_logits = patchscope.logits() - layer_metrics.append(ioi_metric(patched_logits, clean_logits, corrupted_logits, - correct_index, incorrect_index).item()) + layer_metrics.append( + ioi_metric( + patched_logits, + clean_logits, + corrupted_logits, + correct_index, + incorrect_index, + ).item(), + ) metrics.append(layer_metrics) fig = create_heatmap( - patchscope.source_tokens, list(range(n_layers)), metrics, x_label='Token', y_label='Layer', - title='Normalized logit difference after activation patching by layer and position' + patchscope.source_tokens, + list(range(n_layers)), + metrics, + x_label="Token", + y_label="Layer", + title="Normalized logit difference after activation patching by layer and position", ) fig.show() -fig.write_html('activation_patching_ioi_results_layer_pos.html') +fig.write_html("activation_patching_ioi_results_layer_pos.html") # Create logit diff by layer and head @@ -111,7 +123,6 @@ def ioi_metric(patched_logits, clean_logits, corrupt_logits, 
correct_idx, incorr # loop over all heads for head in range(n_heads): - # set the layer and position for patching patchscope.source.layer = layer patchscope.target.layer = layer @@ -123,16 +134,25 @@ def ioi_metric(patched_logits, clean_logits, corrupt_logits, correct_idx, incorr # get the patched logits and calculate the logit difference patched_logits = patchscope.logits() - layer_metrics.append(ioi_metric(patched_logits, clean_logits, corrupted_logits, - correct_index, incorrect_index).item()) + layer_metrics.append( + ioi_metric( + patched_logits, + clean_logits, + corrupted_logits, + correct_index, + incorrect_index, + ).item(), + ) metrics.append(layer_metrics) fig = create_heatmap( - list(range(head)), list(range(n_layers)), metrics, x_label='Head', y_label='Layer', - title='Normalized logit difference after activation patching by layer and head' + list(range(head)), + list(range(n_layers)), + metrics, + x_label="Head", + y_label="Layer", + title="Normalized logit difference after activation patching by layer and head", ) fig.show() -fig.write_html('activation_patching_ioi_results_layer_head.html') - - +fig.write_html("activation_patching_ioi_results_layer_head.html") diff --git a/scripts/future_lens.py b/scripts/future_lens.py index 4cd0784..f3cabe5 100644 --- a/scripts/future_lens.py +++ b/scripts/future_lens.py @@ -1,12 +1,19 @@ +from __future__ import annotations + import os + import torch -from obvs.patchscope import Patchscope, SourceContext, TargetContext +from obvs.patchscope import Patchscope, SourceContext, TargetContext -MODEL_NAME = 'EleutherAI/gpt-j-6b' +MODEL_NAME = "EleutherAI/gpt-j-6b" PREFIX_PATH = os.path.join( os.path.dirname(__file__), - "..", "data", "processed", "gptj_soft_prefix.pt") + "..", + "data", + "processed", + "gptj_soft_prefix.pt", +) DEVICE = "auto" @@ -22,15 +29,17 @@ def future_lens(): layer=-1, position=-1, model_name=MODEL_NAME, - device=DEVICE) + device=DEVICE, + ) target = TargetContext( - prompt=soft_prompt[None,:], + prompt=soft_prompt[None, :], layer=-1, position=-1, model_name=MODEL_NAME, device=DEVICE, - max_new_tokens=4) + max_new_tokens=4, + ) # Might need GPU to load gptj patchscope = Patchscope(source, target) diff --git a/scripts/reproduce_logitlens_results.py b/scripts/reproduce_logitlens_results.py index 81ee678..35ea4d9 100644 --- a/scripts/reproduce_logitlens_results.py +++ b/scripts/reproduce_logitlens_results.py @@ -9,6 +9,8 @@ ) """ +from __future__ import annotations + from obvs.lenses import ClassicLogitLens, PatchscopeLogitLens prompt = """Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training @@ -17,7 +19,7 @@ thousands of examples. By contrast, humans can generally perform a new language task from only a few examples or from simple instructions – something which current NLP systems still largely struggle to do. Here we show that scaling up language models greatly improves task-agnostic, -few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art +few-shot performance, sometimes even reaching competitiveness with prior state-of-the-art finetuning approaches. Specifically, we train GPT-3, an autoregressive language model with 175 billion parameters, 10x more than any previous non-sparse language model, and test its performance in the few-shot setting. For all tasks, GPT-3 is applied without any gradient updates or fine-tuning, @@ -29,20 +31,27 @@ datasets where GPT-3 faces methodological issues related to training on large web corpora. 
Finally, we find that GPT-3 can generate samples of news articles which human evaluators have difficulty distinguishing from articles written by humans. We discuss broader societal impacts of this finding -and of GPT-3 in general.""".replace("\n", " ") +and of GPT-3 in general.""".replace( + "\n", + " ", +) -substring = "Specifically, we train GPT-3, an autoregressive language model with 175 billion " \ - "parameters" +substring = ( + "Specifically, we train GPT-3, an autoregressive language model with 175 billion parameters" +) layers = list(range(0, 12)) # models: gpt2 125m, gpt2 1B, gpt-neo 125m -for model_name in ['gpt2', 'EleutherAI/gpt-neo-125M', 'gpt2-xl']: - +for model_name in ["gpt2", "EleutherAI/gpt-neo-125M", "gpt2-xl"]: # run on both, classic and Patschcope logit lens - for ll_type, ll_class in [('patchscope_logit_lens', PatchscopeLogitLens), - ('classic_logit_lens', ClassicLogitLens)]: - ll = ll_class(model_name, prompt, 'auto') + for ll_type, ll_class in [ + ("patchscope_logit_lens", PatchscopeLogitLens), + ("classic_logit_lens", ClassicLogitLens), + ]: + ll = ll_class(model_name, prompt, "auto") ll.run(substring, layers) fig = ll.visualize() - fig.write_html(f'{model_name.replace("-", "_").replace("/", "_").lower()}_{ll_type}_logits_top_preds.html') + fig.write_html( + f'{model_name.replace("-", "_").replace("/", "_").lower()}_{ll_type}_logits_top_preds.html', + ) diff --git a/scripts/token_identity_prompts.py b/scripts/token_identity_prompts.py index bb25974..5a88be4 100644 --- a/scripts/token_identity_prompts.py +++ b/scripts/token_identity_prompts.py @@ -137,9 +137,6 @@ def main(model_name, target_prompt, samples, full=False): ) args = parser.parse_args() - # for prompt in prompts: - # main(args.model_name, prompt, args.n, args.full) - samples = [] for example in shuffled_dataset.take(args.n): samples.append(example["text"]) diff --git a/setup.cfg b/setup.cfg index 1cfba67..96ec6e9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,6 +6,20 @@ extend-ignore = E203, # whitespace before : is not PEP8 compliant & conflicts with black T100, # line contains FIXME T101, # line contains TODO -# per-file-ignores = - # src/path/file.py" - # E123 + D10, # missing docstring * + D2, # docstring whitespaces + D40, # docstring styles + E501, # line too long + R504, # unnecessary variable assignment before return statement. + R505, # unnecessary else after return statement. + C408, # Unnecessary dict call - rewrite as a literal. + SIM117,# Use single with statement instead of multiple with statements + +per-file-ignores = + # INP001: File is part of an implicit namespace package. Add an __init__.py? + # T201 print found. + scripts/*: INP001, T201 + + # INP001: File is part of an implicit namespace package. Add an __init__.py? + # T201 print found. 
+ benchmarking/*: INP001, T201 diff --git a/tests/conftest.py b/tests/conftest.py index 1ac13f9..cb61767 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,7 @@ def patchscope_llama(): return Patchscope(source_context, target_context) +# pylint: disable=redefined-outer-name @pytest.fixture(autouse=True) def reset_patchscope_fixtures(patchscope: Patchscope): # TODO: should reset patchscope_llama fixture once it works in tests diff --git a/tests/test_patchscopes.py b/tests/test_patchscopes.py index 248d3ce..1257c40 100644 --- a/tests/test_patchscopes.py +++ b/tests/test_patchscopes.py @@ -8,7 +8,6 @@ class TestContext: - @staticmethod def test_source_context_init(): source = SourceContext("source") @@ -39,10 +38,10 @@ def test_soft_prompt_dimensions_must_be_two_or_three(): SourceContext(prompt=torch.ones((1,))) SourceContext(prompt=torch.ones(1, 2)) - SourceContext(prompt=torch.ones((1,2,3))) + SourceContext(prompt=torch.ones((1, 2, 3))) with pytest.raises(ValueError): - SourceContext(prompt=torch.ones((1,2,3,4))) + SourceContext(prompt=torch.ones((1, 2, 3, 4))) @staticmethod def test_prompt_type_must_be_str_or_tensor(): @@ -53,7 +52,6 @@ def test_prompt_type_must_be_str_or_tensor(): SourceContext(prompt=5) - class TestPatchscope: @staticmethod def test_patchscope_init(): @@ -94,7 +92,9 @@ def test_patchscope_map_transpose(): @staticmethod def test_source_tokens(patchscope): patchscope.source.prompt = "a dog is a dog. a cat is a" - assert patchscope.source_token_ids == patchscope.tokenizer.encode("a dog is a dog. a cat is a") + assert patchscope.source_token_ids == patchscope.tokenizer.encode( + "a dog is a dog. a cat is a", + ) @staticmethod def test_source_forward_pass_creates_hidden_state(patchscope): @@ -133,12 +133,11 @@ def test_source_forward_pass_with_attention_head_patching(patchscope): patchscope.source.head = 0 patchscope.source_forward_pass() - assert patchscope._source_hidden_state.value.shape[0] == 1 # Batch size, always 1 - assert patchscope._source_hidden_state.value.shape[1] == len( + assert patchscope._source_hidden_state.shape[0] == 1 # Batch size, always 1 + assert patchscope._source_hidden_state.shape[1] == len( patchscope.source_tokens, ) # Number of tokens assert ( - patchscope._source_hidden_state.value.shape[2] + patchscope._source_hidden_state.shape[2] == patchscope.source_model.transformer.h[0].attn.head_dim ) # Head dimension - diff --git a/tests/test_patchscopes_high_level.py b/tests/test_patchscopes_high_level.py index 9b1ed9e..60dad7b 100644 --- a/tests/test_patchscopes_high_level.py +++ b/tests/test_patchscopes_high_level.py @@ -2,8 +2,6 @@ import pytest -from obvs.patchscope import ModelLoader - class TestPatchscope: @staticmethod @@ -164,7 +162,6 @@ def test_multi_token_generation_with_different_lengths_single_patch(patchscope): # Assert the target has been patched to think a rat is a cat assert "cat" in patchscope.full_output() - @staticmethod def test_soft_prompt(patchscope): soft_prompt = None @@ -175,7 +172,9 @@ def test_soft_prompt(patchscope): patchscope.source.position = -1 patchscope.source.layer = -1 - patchscope.target.prompt = " ".join("_" * soft_prompt.shape[1]) # works for gpt2 & gptj, not sure about others + patchscope.target.prompt = " ".join( + "_" * soft_prompt.shape[1], + ) # works for gpt2 & gptj, not sure about others patchscope.target.position = -1 patchscope.target.layer = -1 patchscope.target.max_new_tokens = 4 @@ -184,7 +183,6 @@ def test_soft_prompt(patchscope): assert "cat" in patchscope.output()[-1] - @staticmethod 
@pytest.mark.skip(reason="This doesn't work") def test_token_identity_prompt_early(patchscope):
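The SourceContext changes in obvs/patchscope.py validate tensor (soft) prompts: a prompt of shape [tokens_len, d_model] is unsqueezed to [batch, tokens_len, d_model], anything that does not end up 3-D raises a ValueError, and the text prompt becomes one "_" placeholder per soft-prompt position. Below is a short sketch of that behaviour, mirroring the updated tests; the sizes (4 tokens, d_model of 2) are arbitrary toy values, not real model dimensions.

import pytest
import torch

from obvs.patchscope import SourceContext

# [tokens_len, d_model] is accepted and treated as a batch of one.
SourceContext(prompt=torch.ones(4, 2))

# [batch, tokens_len, d_model] is accepted as-is.
ctx = SourceContext(prompt=torch.ones((1, 4, 2)))
assert ctx.text_prompt == "_ _ _ _"  # one placeholder per soft-prompt position

# Other shapes are rejected.
with pytest.raises(ValueError):
    SourceContext(prompt=torch.ones((1, 1, 4, 2)))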