chore: Test JumpReLU/Gated SAE and refactor sae forward #328
Description
This PR adds test coverage to the JumpReLU and GatedSAE encode methods.
In doing this, I realized there's a lot of duplication between all the encode variants, so I cleaned that up as well. I think this duplication hid some potential minor bugs. For instance, in `forward()`, when adding the error term, we called `run_time_activation_norm_fn_out()` after `reshape_fn_out()`, but we do the opposite in `decode()`.
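To illustrate why the order of those two output steps can matter, here is a toy NumPy sketch. It assumes, hypothetically, that the output-side norm step multiplies by a per-example coefficient of shape `(batch, 1)` and that the output reshape splits `d_in` into `(n_heads, d_head)`. `norm_out` and `reshape_out` are illustrative stand-ins, not the actual SAELens functions.

```python
import numpy as np

batch, n_heads, d_head = 2, 2, 3
d_in = n_heads * d_head

# Hypothetical per-example coefficient stored on the input side, shape (batch, 1).
norm_coeff = np.array([[2.0], [4.0]])


def norm_out(x):
    # Stand-in for an output-side norm step: rescale each example by its coefficient.
    return x * norm_coeff


def reshape_out(x):
    # Stand-in for an output reshape: split d_in into (n_heads, d_head).
    return x.reshape(batch, n_heads, d_head)


recons = np.ones((batch, d_in))

# Norm first, then reshape: example 0 is scaled by 2.0, example 1 by 4.0.
norm_then_reshape = reshape_out(norm_out(recons))
# Reshape first: the (2, 1) coefficient now broadcasts along the heads axis instead.
reshape_then_norm = norm_out(reshape_out(recons))

print(norm_then_reshape[1, 0])  # [4. 4. 4.]
print(reshape_then_norm[1, 0])  # [2. 2. 2.]
```

Because `batch` happens to equal `n_heads` here, the wrong order broadcasts silently instead of raising a shape error, which is the kind of divergence that is easy to miss when two code paths are maintained separately.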
It looks like most of the duplication between the error-term section of `forward()` and the normal `encode()` / `decode()` was there just to avoid triggering hooks, so I added a context manager, `_disable_hooks()`, which disables hooks in this branch of the code while reusing our existing `encode()` / `decode()` methods. This means we no longer need to worry about the duplicated code paths drifting apart.

If the refactor is out of scope, I can revert the changes to `sae.py` and just leave the test coverage.

Fixes #323
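A minimal, framework-free sketch of the `_disable_hooks()` pattern follows. `ToySAE`, `HookPoint`, `Identity`, the hook names, and the toy encode/decode math are all hypothetical stand-ins (the real code uses PyTorch modules); the point is the shape of the idea: swap each hook point for a no-op inside a `contextlib.contextmanager`, restore it in a `finally` block, and compute the error term by calling the ordinary `encode()` / `decode()`.

```python
from contextlib import contextmanager


class HookPoint:
    """Hypothetical hook point: passes its input through every registered hook."""

    def __init__(self):
        self.hooks = []

    def __call__(self, x):
        for hook in self.hooks:
            x = hook(x)
        return x


class Identity:
    """No-op stand-in used while hooks are disabled."""

    def __call__(self, x):
        return x


class ToySAE:
    # Hypothetical hook names, for illustration only.
    HOOK_NAMES = ("hook_sae_acts_post", "hook_sae_recons")

    def __init__(self):
        for name in self.HOOK_NAMES:
            setattr(self, name, HookPoint())

    def encode(self, x):
        # Toy encoder: ReLU, then the post-activation hook.
        return self.hook_sae_acts_post([max(v, 0.0) for v in x])

    def decode(self, acts):
        # Toy decoder: scale by 2, then the reconstruction hook.
        return self.hook_sae_recons([2.0 * a for a in acts])

    @contextmanager
    def _disable_hooks(self):
        # Save the real hook points, swap in no-ops, and always restore them.
        saved = {name: getattr(self, name) for name in self.HOOK_NAMES}
        try:
            for name in self.HOOK_NAMES:
                setattr(self, name, Identity())
            yield
        finally:
            for name, hook in saved.items():
                setattr(self, name, hook)

    def forward(self, x):
        recons = self.decode(self.encode(x))  # hooks fire on this pass
        with self._disable_hooks():
            # Same encode()/decode() code path, but no hooks fire.
            clean = self.decode(self.encode(x))
        error = [xi - ci for xi, ci in zip(x, clean)]
        return [r + e for r, e in zip(recons, error)]


sae = ToySAE()
print(sae.forward([1.0, -2.0, 3.0]))  # [1.0, -2.0, 3.0]: no hooks, so output equals input
```

Because the error term comes from a hook-free pass through the same methods, a hook that edits a feature changes only the reconstruction part of the output, and the `finally` block restores the hooks even if `encode()` or `decode()` raises.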
Type of change
Checklist:
- You have tested formatting, typing and unit tests (acceptance tests not currently in use)
- You have run `make check-ci` to check format and linting. (You can run `make format` to format code if needed.)