From bbe1eaae32d61c7b9fdd70b2978a584403532dc8 Mon Sep 17 00:00:00 2001
From: Bryce Meyer
Date: Tue, 13 Aug 2024 21:12:58 -0400
Subject: [PATCH 1/3] Upstream commit update (#703)

From 73da2b6fa93f0a9db3f34b27c6fccc062519735b Mon Sep 17 00:00:00 2001
From: Oliver Daniels-Koch <40397426+oliveradk@users.noreply.github.com>
Date: Thu, 15 Aug 2024 19:40:26 -0400
Subject: [PATCH 2/3] removed einsum causing error when use_attn_result is enabled (#660)

* removed einsum causing error when use_attn_result is enabled

* removed fancy einsum (re-added accidentally)

* removed extra space

---------

Co-authored-by: Bryce Meyer
---
 tests/unit/test_use_attn_result.py | 73 ++++++++++++++++++++++++++++++
 1 file changed, 73 insertions(+)
 create mode 100644 tests/unit/test_use_attn_result.py

diff --git a/tests/unit/test_use_attn_result.py b/tests/unit/test_use_attn_result.py
new file mode 100644
index 000000000..417374f59
--- /dev/null
+++ b/tests/unit/test_use_attn_result.py
@@ -0,0 +1,73 @@
+import torch
+
+from transformer_lens import HookedTransformer
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+def test_atten_result_normal_attn_correct():
+    """Verifies that the attn_result flag does not change the output for models with normal attention."""
+    d_model = 128
+    d_head = 8
+    n_heads = 16
+    n_ctx = 128
+    n_layers = 1
+    d_vocab = 10
+
+    cfg = HookedTransformerConfig(
+        d_model=d_model,
+        d_head=d_head,
+        n_heads=n_heads,
+        n_ctx=n_ctx,
+        n_layers=n_layers,
+        attn_only=True,
+        d_vocab=d_vocab,
+    )
+
+    model = HookedTransformer(cfg)
+    assert model.cfg.use_split_qkv_input is False
+
+    x = torch.arange(1, 9).unsqueeze(0)
+    normal_output = model(x)
+
+    model.set_use_attn_result(True)
+    assert model.cfg.use_attn_result is True
+
+    split_output = model(x)
+
+    assert torch.allclose(normal_output, split_output, atol=1e-6)
+
+
+def test_atten_result_grouped_query_attn_correct():
+    """Verifies that the attn_result flag does not change the output for models with grouped query attention."""
+
+    d_model = 128
+    d_head = 8
+    n_heads = 16
+    n_ctx = 128
+    n_key_value_heads = 2
+    n_layers = 1
+    d_vocab = 10
+
+    cfg = HookedTransformerConfig(
+        d_model=d_model,
+        d_head=d_head,
+        n_heads=n_heads,
+        n_ctx=n_ctx,
+        n_key_value_heads=n_key_value_heads,
+        n_layers=n_layers,
+        attn_only=True,
+        d_vocab=d_vocab,
+    )
+
+    model = HookedTransformer(cfg)
+    assert model.cfg.use_split_qkv_input is False
+
+    x = torch.arange(1, 9).unsqueeze(0)
+    normal_output = model(x)
+
+    model.set_use_attn_result(True)
+    assert model.cfg.use_attn_result is True
+
+    split_output = model(x)
+
+    assert torch.allclose(normal_output, split_output, atol=1e-6)
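A note on what the tests above exercise: with use_attn_result enabled, TransformerLens also exposes each head's contribution to the residual stream at the hook point blocks.{layer}.attn.hook_result. The sketch below (not part of the patch) reuses the same toy attention-only config as the tests; the decomposition check and tolerance are illustrative assumptions rather than something the patch itself asserts.

import torch

from transformer_lens import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig

# Same toy attention-only model as in the tests above.
cfg = HookedTransformerConfig(
    d_model=128, d_head=8, n_heads=16, n_ctx=128, n_layers=1, attn_only=True, d_vocab=10
)
model = HookedTransformer(cfg)
model.set_use_attn_result(True)

x = torch.arange(1, 9).unsqueeze(0)
_, cache = model.run_with_cache(x)

# hook_result holds the per-head contribution to the residual stream:
# shape [batch, pos, head_index, d_model].
per_head = cache["blocks.0.attn.hook_result"]

# Summing over heads and adding the output bias should (approximately)
# recover the attention layer's output.
recombined = per_head.sum(dim=2) + model.blocks[0].attn.b_O
assert torch.allclose(recombined, cache["blocks.0.hook_attn_out"], atol=1e-5)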
""" - import logging import os from typing import Dict, List, NamedTuple, Optional, Tuple, Union, cast, overload @@ -1570,7 +1569,10 @@ def load_and_process_state_dict( # so that quantization settings are not lost self.load_state_dict(state_dict, assign=True, strict=False) else: - self.load_state_dict(state_dict, strict=False) + state_dict_keys = list(state_dict.keys()) + for key in state_dict_keys: + self.load_state_dict({key: state_dict[key]}, strict=False) + del state_dict[key] def fill_missing_keys(self, state_dict): return loading.fill_missing_keys(self, state_dict)