diff --git a/tests/unit/test_use_attn_result.py b/tests/unit/test_use_attn_result.py
new file mode 100644
index 000000000..417374f59
--- /dev/null
+++ b/tests/unit/test_use_attn_result.py
@@ -0,0 +1,73 @@
+import torch
+
+from transformer_lens import HookedTransformer
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+def test_attn_result_normal_attn_correct():
+    """Verifies that the attn_result flag does not change the output for models with normal attention."""
+    d_model = 128
+    d_head = 8
+    n_heads = 16
+    n_ctx = 128
+    n_layers = 1
+    d_vocab = 10
+
+    cfg = HookedTransformerConfig(
+        d_model=d_model,
+        d_head=d_head,
+        n_heads=n_heads,
+        n_ctx=n_ctx,
+        n_layers=n_layers,
+        attn_only=True,
+        d_vocab=d_vocab,
+    )
+
+    model = HookedTransformer(cfg)
+    assert model.cfg.use_split_qkv_input is False
+
+    x = torch.arange(1, 9).unsqueeze(0)
+    normal_output = model(x)
+
+    model.set_use_attn_result(True)
+    assert model.cfg.use_attn_result is True
+
+    split_output = model(x)
+
+    assert torch.allclose(normal_output, split_output, atol=1e-6)
+
+
+def test_attn_result_grouped_query_attn_correct():
+    """Verifies that the attn_result flag does not change the output for models with grouped query attention."""
+
+    d_model = 128
+    d_head = 8
+    n_heads = 16
+    n_ctx = 128
+    n_key_value_heads = 2
+    n_layers = 1
+    d_vocab = 10
+
+    cfg = HookedTransformerConfig(
+        d_model=d_model,
+        d_head=d_head,
+        n_heads=n_heads,
+        n_ctx=n_ctx,
+        n_key_value_heads=n_key_value_heads,
+        n_layers=n_layers,
+        attn_only=True,
+        d_vocab=d_vocab,
+    )
+
+    model = HookedTransformer(cfg)
+    assert model.cfg.use_split_qkv_input is False
+
+    x = torch.arange(1, 9).unsqueeze(0)
+    normal_output = model(x)
+
+    model.set_use_attn_result(True)
+    assert model.cfg.use_attn_result is True
+
+    split_output = model(x)
+
+    assert torch.allclose(normal_output, split_output, atol=1e-6)
diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 8ee2e74f8..8e32a9cd6 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -8,7 +8,6 @@ alteration of activations in individual components like attention heads and
 MLP layers, facilitating a deeper understanding of the internal workings of
 transformers like GPT-2.
 """
-
 import logging
 import os
 from typing import Dict, List, NamedTuple, Optional, Tuple, Union, cast, overload
@@ -1570,7 +1569,12 @@ def load_and_process_state_dict(
             # so that quantization settings are not lost
             self.load_state_dict(state_dict, assign=True, strict=False)
         else:
-            self.load_state_dict(state_dict, strict=False)
+            # Load one entry at a time and free it immediately afterwards,
+            # so the full state dict is never held in memory alongside the model.
+            state_dict_keys = list(state_dict.keys())
+            for key in state_dict_keys:
+                self.load_state_dict({key: state_dict[key]}, strict=False)
+                del state_dict[key]

    def fill_missing_keys(self, state_dict):
        return loading.fill_missing_keys(self, state_dict)
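
The per-key loop in the second hunk is a peak-memory optimization: rather than passing the whole state dict to load_state_dict at once, each tensor is loaded and then released before the next one is touched, so the weights are never resident twice. A minimal sketch of the same pattern on a plain torch.nn.Module (the nn.Linear module and the cloned state dict are illustrative stand-ins, not part of this change):

    import torch.nn as nn

    # Illustrative module and state dict; any nn.Module works the same way.
    model = nn.Linear(4, 4)
    state_dict = {k: v.clone() for k, v in model.state_dict().items()}

    # Load one entry at a time and drop the reference immediately, so peak
    # memory is the model plus a single tensor rather than the model plus
    # the entire state dict.
    for key in list(state_dict.keys()):
        model.load_state_dict({key: state_dict[key]}, strict=False)
        del state_dict[key]

    assert len(state_dict) == 0  # every entry was consumed and freed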