Skip to content

Commit

Permalink
revised loading to recycle state dict (#706)
Browse files Browse the repository at this point in the history
* revised loading to recycle state dict

* removed manuall gc collection
  • Loading branch information
bryce13950 authored Sep 5, 2024
1 parent 73da2b6 commit db1a7f5
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions transformer_lens/HookedTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
alteration of activations in individual components like attention heads and MLP layers, facilitating
a deeper understanding of the internal workings of transformers like GPT-2.
"""

import logging
import os
from typing import Dict, List, NamedTuple, Optional, Tuple, Union, cast, overload
Expand Down Expand Up @@ -1570,7 +1569,10 @@ def load_and_process_state_dict(
# so that quantization settings are not lost
self.load_state_dict(state_dict, assign=True, strict=False)
else:
self.load_state_dict(state_dict, strict=False)
state_dict_keys = list(state_dict.keys())
for key in state_dict_keys:
self.load_state_dict({key: state_dict[key]}, strict=False)
del state_dict[key]

def fill_missing_keys(self, state_dict):
return loading.fill_missing_keys(self, state_dict)
Expand Down

0 comments on commit db1a7f5

Please sign in to comment.