Skip to content

Commit

Permalink
remove logging and clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
frankaging committed Jul 25, 2024
1 parent 7510aa5 commit 5bee6b8
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pyreft/interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,13 @@ def state_dict(self, *args, **kwargs):
for k, v in self.learned_source.state_dict().items():
state_dict[k] = v
state_dict["rotate_layer"] = self.rotate_layer.weight.data
print(self.rotate_layer.weight.data)
return state_dict

def load_state_dict(self, state_dict, *args, **kwargs):
"""
Overwrite for data-efficiency.
"""
super().load_state_dict(state_dict, strict=False)
self.learned_source.load_state_dict(state_dict, strict=False)

# Caveat: without creating a new layer, it might not work (still not sure why)
# We have to recreate a layer, and load back the columns.
Expand All @@ -77,7 +76,8 @@ def load_state_dict(self, state_dict, *args, **kwargs):
self.embed_dim, overload_w_width, init_orth=True).to(
self.learned_source.weight.device)
self.rotate_layer = torch.nn.utils.parametrizations.orthogonal(rotate_layer)
self.rotate_layer.parametrizations.weight[0].base[:,:overload_w_width] = overload_w.to("cuda")
self.rotate_layer.parametrizations.weight[0].base[:,:overload_w_width] = overload_w
assert torch.allclose(self.rotate_layer.weight.data, overload_w.data) == True # we must match!

return

Expand Down

0 comments on commit 5bee6b8

Please sign in to comment.