Skip to content

Commit

Permalink
added kwargs to load_state_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
NiklasGebauer committed May 8, 2023
1 parent 30872c8 commit acbe70e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/schnetpack_gschnet/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def setup(self, stage=None):
if stage == "fit":
self.model.initialize_transforms(dm)

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def load_state_dict(self, state_dict: Dict[str, Any], **kwargs) -> None:
# make sure that cutoff values have not been changed
for name, val1 in [
("model_cutoff", self.model.model_cutoff),
Expand All @@ -112,7 +112,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
f"{name}. Please set it to {val2:.2f} or train a new model."
)
# load checkpoint
super().load_state_dict(state_dict)
super().load_state_dict(state_dict, **kwargs)

def loss_fn(self, pred, batch, return_individual_losses=False):
# calculate loss on type predictions (NLL loss using atomic types as classes)
Expand Down

0 comments on commit acbe70e

Please sign in to comment.