diff --git a/demos/Colab_Compatibility.ipynb b/demos/Colab_Compatibility.ipynb index 2fa9838eb..2f28a37a1 100644 --- a/demos/Colab_Compatibility.ipynb +++ b/demos/Colab_Compatibility.ipynb @@ -58,14 +58,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "TransformerLens currently supports 205 models out of the box.\n" + "TransformerLens currently supports 206 models out of the box.\n" ] } ], @@ -429,6 +429,7 @@ " \"meta-llama/Llama-2-70b-chat-hf\",\n", " \"meta-llama/Llama-3.1-70B\",\n", " \"meta-llama/Llama-3.1-70B-Instruct\",\n", + " \"meta-llama/Llama-3.3-70B-Instruct\",\n", " \"meta-llama/Meta-Llama-3-70B\",\n", " \"meta-llama/Meta-Llama-3-70B-Instruct\",\n", " \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n", diff --git a/tests/integration/test_hooks.py b/tests/integration/test_hooks.py index 6a9880a67..29d5ff9ed 100644 --- a/tests/integration/test_hooks.py +++ b/tests/integration/test_hooks.py @@ -234,3 +234,10 @@ def set_to_randn(z, hook): # exactly when the zero hook is attached last XOR it is prepended assert torch.allclose(logits, model.unembed.b_U[None, :]) == logits_are_unembed_bias + + +def test_use_attn_in_with_gqa_raises_error(): + # Create model that uses GroupedQueryAttention + model = HookedTransformer.from_pretrained("Qwen/Qwen2-0.5B") + with pytest.raises(AssertionError): + model.set_use_attn_in(True) diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 500098c32..a34a5c4a0 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -1311,7 +1311,7 @@ def from_pretrained( center_writing_weights = False if center_unembed and cfg.output_logits_soft_cap > 0.0: logging.warning( - "You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant" + "You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant " "Setting center_unembed=False instead." ) center_unembed = False @@ -1969,6 +1969,9 @@ def set_use_attn_in(self, use_attn_in: bool): """ Toggles whether to allow editing of inputs to each attention head. """ + assert ( + self.cfg.n_key_value_heads is None + ), "Can't use attn_in with GroupedQueryAttention, please use split_qkv_input instead" self.cfg.use_attn_in = use_attn_in def set_ungroup_grouped_query_attention(self, ungroup_grouped_query_attention: bool): diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index b4ecc8d64..17d32e8c7 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -151,14 +151,15 @@ "meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3-70B", "meta-llama/Meta-Llama-3-70B-Instruct", - "meta-llama/Llama-3.2-1B", - "meta-llama/Llama-3.2-3B", - "meta-llama/Llama-3.2-1B-Instruct", - "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.1-70B", "meta-llama/Llama-3.1-8B", "meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-3.1-70B-Instruct", + "meta-llama/Llama-3.2-1B", + "meta-llama/Llama-3.2-3B", + "meta-llama/Llama-3.2-1B-Instruct", + "meta-llama/Llama-3.2-3B-Instruct", + "meta-llama/Llama-3.3-70B-Instruct", "Baidicoot/Othello-GPT-Transformer-Lens", "bert-base-cased", "roneneldan/TinyStories-1M", @@ -960,6 +961,30 @@ def convert_hf_model_config(model_name: str, **kwargs): "NTK_by_parts_high_freq_factor": 4.0, "NTK_by_parts_factor": 32.0, } + elif "Llama-3.3-70B" in official_model_name: + cfg_dict = { + "d_model": 8192, + "d_head": 128, + "n_heads": 64, + "d_mlp": 28672, + "n_layers": 80, + "n_ctx": 2048, # capped due to memory issues + "eps": 1e-5, + "d_vocab": 128256, + "act_fn": "silu", + "n_key_value_heads": 8, + "normalization_type": "RMS", + "positional_embedding_type": "rotary", + "rotary_adjacent_pairs": False, + "rotary_dim": 32, + "final_rms": True, + "gated_mlp": True, + "rotary_base": 500000.0, + "use_NTK_by_parts_rope": True, + "NTK_by_parts_low_freq_factor": 1.0, + "NTK_by_parts_high_freq_factor": 4.0, + "NTK_by_parts_factor": 8.0, + } elif "Llama-3.1-8B" in official_model_name: cfg_dict = { "d_model": 4096, @@ -1241,6 +1266,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "trust_remote_code": True, "final_rms": True, "gated_mlp": True, + "default_prepend_bos": False, } elif architecture == "Qwen2ForCausalLM": # Note that Qwen1.5 models have architecture type Qwen2ForCausalLM. @@ -1259,12 +1285,13 @@ def convert_hf_model_config(model_name: str, **kwargs): "initializer_range": hf_config.initializer_range, "normalization_type": "RMS", "positional_embedding_type": "rotary", - "rotary_base": hf_config.rope_theta, + "rotary_base": int(hf_config.rope_theta), "rotary_adjacent_pairs": False, "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, "tokenizer_prepends_bos": True, "final_rms": True, "gated_mlp": True, + "default_prepend_bos": False, } elif architecture == "PhiForCausalLM": # Architecture for microsoft/phi models @@ -1325,7 +1352,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "act_fn": "gelu_new", "initializer_range": 0.02, "normalization_type": "RMS", - "rotary_base": 10000.0, + "rotary_base": 10000, "rotary_dim": 256, "positional_embedding_type": "rotary", "use_attn_scale": True,