Skip to content

Commit

Permalink
minor fix on unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
frankaging committed Jan 11, 2024
1 parent d2ad05c commit bacdb42
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion models/gru/modelings_gru.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def forward(
return SequenceClassifierOutput(
loss=loss,
logits=pooled_logits,
hidden_states=mlp_outputs.hidden_states,
hidden_states=gru_outputs.hidden_states,
)


Expand Down
5 changes: 4 additions & 1 deletion models/interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,10 @@ class PCARotatedSpaceIntervention(BasisAgnosticIntervention):
"""Intervention in the pca space."""

def __init__(self, embed_dim, **kwargs):
super().__init__(embed_dim, **kwargs)
super().__init__(**kwargs)
pca = kwargs["pca"]
pca_mean = kwargs["pca_mean"]
pca_std = kwargs["pca_std"]
self.pca_components = torch.nn.Parameter(
torch.tensor(pca.components_, dtype=torch.float32), requires_grad=False
)
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
)

ONE_MLP_WITH_W1_OUT_RUN = (
lambda w1_out, mlp: mlp.mlp.h[0].act(w1_act) @ mlp.score.weight.T
lambda w1_out, mlp: mlp.mlp.h[0].act(w1_out) @ mlp.score.weight.T
)

ONE_MLP_WITH_W1_ACT_RUN = lambda w1_act, mlp: w1_act @ mlp.score.weight.T
Expand Down

0 comments on commit bacdb42

Please sign in to comment.