From d8d41ddffdbece6b29b4bfd2f8db169d5ae06f8e Mon Sep 17 00:00:00 2001
From: callummcdougall
Date: Fri, 4 Oct 2024 10:52:06 +0100
Subject: [PATCH 1/2] first commit

---
 transformer_lens/HookedTransformer.py | 19 ++++++++++++++++---
 1 file changed, 16 insertions(+), 3 deletions(-)

diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index a5c53b222..3fdd1c1ed 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -10,7 +10,18 @@
 """
 import logging
 import os
-from typing import Dict, List, NamedTuple, Optional, Tuple, Union, cast, overload
+from typing import (
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
 
 import einops
 import numpy as np
@@ -67,6 +78,8 @@
     "bf16": torch.bfloat16,
 }
 
+T = TypeVar("T", bound="HookedTransformer")
+
 
 class Output(NamedTuple):
     """Output Named Tuple.
@@ -1053,7 +1066,7 @@ def move_model_modules_to_device(self):
 
     @classmethod
     def from_pretrained(
-        cls,
+        cls: Type[T],
         model_name: str,
         fold_ln: bool = True,
         center_writing_weights: bool = True,
@@ -1072,7 +1085,7 @@ def from_pretrained(
         dtype="float32",
         first_n_layers: Optional[int] = None,
         **from_pretrained_kwargs,
-    ) -> "HookedTransformer":
+    ) -> T:
         """Load in a Pretrained Model.
 
         Load in pretrained model weights to the HookedTransformer format and optionally to do some

From 64d0e5f951d385031397bcd8cdd5766e5c26e0cc Mon Sep 17 00:00:00 2001
From: Albert Garde
Date: Sat, 5 Oct 2024 00:39:13 +0200
Subject: [PATCH 2/2] Avoid warning in `utils.download_file_from_hf` (#739)

Add `weights_only=False` argument to `torch.load`. Starting in
`torch=2.4.1`, `torch.load` prints a warning when the `weights_only`
argument is not given.
---
 transformer_lens/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py
index 421d35e15..d68bc561f 100644
--- a/transformer_lens/utils.py
+++ b/transformer_lens/utils.py
@@ -59,7 +59,7 @@ def download_file_from_hf(
     )
 
     if file_path.endswith(".pth") or force_is_torch:
-        return torch.load(file_path, map_location="cpu")
+        return torch.load(file_path, map_location="cpu", weights_only=False)
     elif file_path.endswith(".json"):
         return json.load(open(file_path, "r"))
     else:
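
A minimal sketch of the typing pattern introduced in PATCH 1/2, shown with hypothetical stand-in classes (`Base`/`Sub`) rather than the real HookedTransformer: binding a TypeVar to the class and annotating `cls: Type[T]` makes `from_pretrained` resolve to the type of the subclass it is called on, instead of always the base class.

from typing import Type, TypeVar

T = TypeVar("T", bound="Base")


class Base:
    @classmethod
    def from_pretrained(cls: Type[T], model_name: str) -> T:
        # Stand-in body: the real method loads pretrained weights into a
        # HookedTransformer; here we only construct an instance so the
        # return-type inference is visible.
        return cls()


class Sub(Base):
    pass


model = Sub.from_pretrained("gpt2")
# Static checkers now infer `model: Sub` rather than the base class, which is
# what replacing the hard-coded `-> "HookedTransformer"` annotation provides.
assert isinstance(model, Sub)

PATCH 2/2 keeps the pre-2.4.1 behaviour of `torch.load` by passing `weights_only=False` explicitly, which silences the new warning while still allowing full pickle-based deserialisation of `.pth` checkpoint files.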