From a7a3b2382e582b57784974098b7be4b759b31249 Mon Sep 17 00:00:00 2001
From: Mateusz
Date: Wed, 7 Aug 2024 09:18:24 +0000
Subject: [PATCH 1/2] Update Gemma2 attention scale

---
 transformer_lens/loading_from_pretrained.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index db33e5b98..452cb6f3b 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -1246,7 +1246,6 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "rotary_base": 10000.0,
             "positional_embedding_type": "rotary",
             "use_attn_scale": True,
-            "attn_scale": math.sqrt(224),
             "n_key_value_heads": 4,
             "window_size": 4096,
             "use_local_attn": True,
@@ -1274,7 +1273,6 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "rotary_base": 10000.0,
             "positional_embedding_type": "rotary",
             "use_attn_scale": True,
-            "attn_scale": math.sqrt(224),
             "n_key_value_heads": 8,
             "window_size": 4096,
             "use_local_attn": True,

From ccf4cacb097c85d65eade804c3ebf768bea82f99 Mon Sep 17 00:00:00 2001
From: Mateusz
Date: Wed, 7 Aug 2024 19:03:21 +0000
Subject: [PATCH 2/2] remove import

---
 transformer_lens/loading_from_pretrained.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 452cb6f3b..7c36efdd7 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -5,7 +5,6 @@
 
 import dataclasses
 import logging
-import math
 import os
 import re
 from pathlib import Path