
[Bug Report] read_sae_from_disk should default to the SAE dtype unless it's overridden #230

Open
1 task done
hijohnnylin opened this issue Jul 13, 2024 · 0 comments
Labels
good first issue Good for newcomers

Comments

@hijohnnylin
Collaborator

Describe the bug
sae_lens_loader in pretrained_sae_loaders.py does not currently pass dtype to read_sae_from_disk, so the weights are always loaded as float32 first.

The final dtype is correct, because the SAE is created and initialized with the dtype from cfg_dict. Ideally, however, the weights would be in the correct dtype the whole time.

The fix seems to be for read_sae_from_disk to 1) make dtype optional and 2) default to the dtype specified in cfg_dict when no dtype is passed. Other methods that call read_sae_from_disk would then get the correct behavior by default, unless they choose to override it.
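A minimal sketch of the proposed default, written in plain Python so it runs without torch. The helper names resolve_load_dtype and load_weights_sketch are hypothetical, not SAELens functions, and the sketch assumes cfg_dict stores the dtype as a string such as "float32". Falling back to float32 when the key is missing keeps the current behavior.

```python
from typing import Any, Dict, List, Optional


def resolve_load_dtype(cfg_dict: Dict[str, Any], dtype: Optional[str] = None) -> str:
    """Pick the dtype to load weights in (hypothetical helper).

    An explicit dtype argument wins; otherwise use cfg_dict["dtype"];
    if the config has no dtype, fall back to float32 (the current behavior).
    """
    if dtype is not None:
        return dtype
    return cfg_dict.get("dtype", "float32")


def load_weights_sketch(
    cfg_dict: Dict[str, Any],
    weight_names: List[str],
    dtype: Optional[str] = None,
) -> Dict[str, str]:
    """Hypothetical stand-in for read_sae_from_disk: instead of reading real
    tensors, map each weight name to the dtype it would be loaded in."""
    target = resolve_load_dtype(cfg_dict, dtype)
    return {name: target for name in weight_names}
```

With this default, a caller that passes only cfg_dict gets the config's dtype (load_weights_sketch({"dtype": "bfloat16"}, ["W_enc"]) maps W_enc to "bfloat16"), while passing dtype="float16" still overrides it.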

I'll make a PR for this soon.

Checklist

  • I have checked that there is no similar issue in the repo (required)
@jbloomAus jbloomAus added the good first issue Good for newcomers label Sep 26, 2024
Development

No branches or pull requests

2 participants