Describe the bug
`sae_lens_loader` in `pretrained_sae_loaders.py` currently does not pass `dtype` to `read_sae_from_disk`, so the weights are always loaded as `float32` first.
The final dtype is still correct, because the SAE is created and initialized with the dtype from `cfg_dict`. Ideally, though, the correct dtype would be used throughout loading.
A likely fix is for `read_sae_from_disk` to 1) make `dtype` optional and 2) fall back to the dtype specified in `cfg_dict` when no `dtype` is passed. Other functions that call `read_sae_from_disk` would then get the correct behavior by default, while still being able to override it.
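A minimal sketch of the proposed fallback, not the actual SAELens code. The names `resolve_load_dtype` and `read_sae_from_disk_sketch` are hypothetical. Dtypes are plain strings rather than `torch.dtype` objects so the snippet runs without PyTorch, and falling back to `"float32"` when `cfg_dict` has no `dtype` key is an assumption.

```python
from typing import Any, Optional

# Assumed default, used only when neither the caller nor cfg_dict names a dtype.
FALLBACK_DTYPE = "float32"


def resolve_load_dtype(cfg_dict: dict[str, Any], dtype: Optional[str] = None) -> str:
    """Pick the dtype to load weights in: an explicit argument wins, else cfg_dict["dtype"]."""
    if dtype is not None:
        return dtype
    return cfg_dict.get("dtype", FALLBACK_DTYPE)


def read_sae_from_disk_sketch(
    cfg_dict: dict[str, Any],
    weights: dict[str, list[float]],
    dtype: Optional[str] = None,
) -> tuple[dict[str, Any], dict[str, tuple[str, list[float]]]]:
    """Hypothetical stand-in for read_sae_from_disk.

    Real code would read tensors from a weight file and cast them to the
    resolved dtype; here each weight is only tagged with that dtype.
    """
    load_dtype = resolve_load_dtype(cfg_dict, dtype)
    state_dict = {name: (load_dtype, values) for name, values in weights.items()}
    return cfg_dict, state_dict


if __name__ == "__main__":
    cfg = {"dtype": "bfloat16"}
    _, state_dict = read_sae_from_disk_sketch(cfg, {"W_enc": [0.1, 0.2]})
    print(state_dict["W_enc"][0])  # bfloat16: taken from cfg_dict
    _, state_dict = read_sae_from_disk_sketch(cfg, {"W_enc": [0.1, 0.2]}, dtype="float32")
    print(state_dict["W_enc"][0])  # float32: the explicit override wins
```

Making `dtype` default to `None` rather than `float32` is what lets callers that pass nothing inherit the config's dtype, while any caller that passes a value keeps full control.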
I'll make a PR for this soon.
Checklist
I have checked that there is no similar issue in the repo (required)