Implement hook_mlp_in for parallel attention/MLP models (#380)
* Implement attention

* And add some more safety
ArthurConmy authored Sep 10, 2023
1 parent f68a7fb commit 3745d0c
Showing 2 changed files with 7 additions and 1 deletion.
2 changes: 2 additions & 0 deletions transformer_lens/HookedTransformer.py
@@ -1450,6 +1450,8 @@ def set_use_hook_mlp_in(self, use_hook_mlp_in: bool):
         """
         Toggles whether to allow storing and editing inputs to each MLP layer.
         """
+
+        assert not self.cfg.attn_only, "Can't use hook_mlp_in with attn_only model"
         self.cfg.use_hook_mlp_in = use_hook_mlp_in

     def process_weights_(
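
For context, a minimal sketch of how this toggle might be used after the change. It assumes the standard "gpt2" checkpoint and the toy "attn-only-2l" model are available; neither is part of this commit:

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
# With the flag set, blocks.{layer}.hook_mlp_in fires on later forward passes.
model.set_use_hook_mlp_in(True)

# The new assert rejects attention-only models, which have no MLP layers to hook.
attn_only = HookedTransformer.from_pretrained("attn-only-2l")
try:
    attn_only.set_use_hook_mlp_in(True)
except AssertionError as err:
    print(err)  # Can't use hook_mlp_in with attn_only model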
6 changes: 5 additions & 1 deletion transformer_lens/components.py
@@ -1069,7 +1069,11 @@ def add_head_dimension(tensor):
         elif self.cfg.parallel_attn_mlp:
             # Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used.
             # In GPT-J, LN1 and LN2 are tied, in GPT-NeoX they aren't.
-            normalized_resid_pre_2 = self.ln2(resid_pre)
+            normalized_resid_pre_2 = self.ln2(
+                resid_pre
+                if not self.cfg.use_hook_mlp_in
+                else self.hook_mlp_in(resid_pre.clone())
+            )
             mlp_out = self.hook_mlp_out(
                 self.mlp(normalized_resid_pre_2)
             )  # [batch, pos, d_model]
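
A minimal sketch (not part of this commit) of editing the MLP input in a parallel attention/MLP model. "pythia-70m" is assumed here to be a GPT-NeoX-style checkpoint that loads with cfg.parallel_attn_mlp set to True; any GPT-J-style model would behave the same way:

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("pythia-70m")
model.set_use_hook_mlp_in(True)

def zero_mlp_in(mlp_in, hook):
    # mlp_in is the cloned resid_pre feeding the MLP branch (before ln2);
    # editing it leaves the attention branch's view of resid_pre untouched.
    return mlp_in * 0.0

tokens = model.to_tokens("Hello world")
logits = model.run_with_hooks(
    tokens, fwd_hooks=[("blocks.0.hook_mlp_in", zero_mlp_in)]
)

The clone() in the diff above is what makes this safe: the hook hands out a copy of resid_pre, so in-place edits inside the hook cannot leak into the attention path.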
