Merge pull request #712 from TransformerLensOrg/dev
v2.4.1
bryce13950 authored Sep 5, 2024
2 parents cb5017a + db1a7f5 commit dd8c1e0
Showing 2 changed files with 77 additions and 2 deletions.
73 changes: 73 additions & 0 deletions tests/unit/test_use_attn_result.py
@@ -0,0 +1,73 @@
import torch

from transformer_lens import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def test_atten_result_normal_attn_correct():
"""Verifies that the attn_result flag does not change the output for models with normal attention."""
d_model = 128
d_head = 8
n_heads = 16
n_ctx = 128
n_layers = 1
d_vocab = 10

cfg = HookedTransformerConfig(
d_model=d_model,
d_head=d_head,
n_heads=n_heads,
n_ctx=n_ctx,
n_layers=n_layers,
attn_only=True,
d_vocab=d_vocab,
)

model = HookedTransformer(cfg)
assert model.cfg.use_split_qkv_input is False

x = torch.arange(1, 9).unsqueeze(0)
normal_output = model(x)

model.set_use_attn_result(True)
assert model.cfg.use_attn_result is True

split_output = model(x)

assert torch.allclose(normal_output, split_output, atol=1e-6)


def test_atten_result_grouped_query_attn_correct():
"""Verifies that the atten_result flag does not change the output for models with grouped query attention."""

d_model = 128
d_head = 8
n_heads = 16
n_ctx = 128
n_key_value_heads = 2
n_layers = 1
d_vocab = 10

cfg = HookedTransformerConfig(
d_model=d_model,
d_head=d_head,
n_heads=n_heads,
n_ctx=n_ctx,
n_key_value_heads=n_key_value_heads,
n_layers=n_layers,
attn_only=True,
d_vocab=d_vocab,
)

model = HookedTransformer(cfg)
assert model.cfg.use_split_qkv_input is False

x = torch.arange(1, 9).unsqueeze(0)
normal_output = model(x)

model.set_use_attn_result(True)
assert model.cfg.use_attn_result is True

split_output = model(x)

assert torch.allclose(normal_output, split_output, atol=1e-6)
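
Editorial note, not part of this commit: the use_attn_result flag these tests exercise makes TransformerLens cache each attention head's output to the residual stream separately, and the tests check that turning it on leaves the model's logits unchanged. A minimal usage sketch, assuming the standard blocks.<layer>.attn.hook_result hook name and the config from the first test:

model.set_use_attn_result(True)
_, cache = model.run_with_cache(x)
# Per-head contributions, shape [batch, pos, head_index, d_model];
# here torch.Size([1, 8, 16, 128]) given batch=1, 8 tokens, 16 heads, d_model=128.
per_head_result = cache["blocks.0.attn.hook_result"]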
6 changes: 4 additions & 2 deletions transformer_lens/HookedTransformer.py
@@ -8,7 +8,6 @@
 alteration of activations in individual components like attention heads and MLP layers, facilitating
 a deeper understanding of the internal workings of transformers like GPT-2.
 """
-
 import logging
 import os
 from typing import Dict, List, NamedTuple, Optional, Tuple, Union, cast, overload
@@ -1570,7 +1569,10 @@ def load_and_process_state_dict(
             # so that quantization settings are not lost
             self.load_state_dict(state_dict, assign=True, strict=False)
         else:
-            self.load_state_dict(state_dict, strict=False)
+            state_dict_keys = list(state_dict.keys())
+            for key in state_dict_keys:
+                self.load_state_dict({key: state_dict[key]}, strict=False)
+                del state_dict[key]
 
     def fill_missing_keys(self, state_dict):
         return loading.fill_missing_keys(self, state_dict)
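
Editorial note on the HookedTransformer.py change: loading the state dict one key at a time and deleting each entry after it is assigned (presumably to reduce peak memory) means a full copy of the weights never sits in memory alongside the model's own parameters. A minimal sketch of the same pattern for a generic PyTorch module; the helper name load_state_dict_incrementally is hypothetical:

import torch.nn as nn

def load_state_dict_incrementally(model: nn.Module, state_dict: dict) -> None:
    # Hypothetical helper mirroring the pattern in this commit: assign one
    # tensor at a time, then drop it from the source dict so that at peak
    # only roughly one extra tensor is held in addition to the model's
    # parameters, rather than a second full copy of the weights.
    for key in list(state_dict.keys()):
        model.load_state_dict({key: state_dict[key]}, strict=False)
        del state_dict[key]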
