Skip to content

Commit

Permalink
fix: patches input grad toggling based on incoming batch
Browse files Browse the repository at this point in the history
This adjusts the logic, albeit maybe inconsistent with the rest of multitask, where
we check the incoming batch for dataset names at the top level to determine if it's
a multidata batch, instead of relying on the model expectations.

This fixes the ase calculate behavior, which would have been mismatched since the
module is inherently multidata but the incoming batch is not.
  • Loading branch information
laserkelvin committed May 24, 2024
1 parent 2f8894c commit 6e54dad
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion matsciml/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2357,7 +2357,9 @@ def _toggle_input_grads(
"""
need_grad_keys = getattr(self, "input_grad_keys", None)
if need_grad_keys is not None:
if self.is_multidata:
# we determine if it's multidata based on the incoming batch
# as it should have dataset in its key
if any(["Dataset" in key for key in batch.keys()]):
# if this is a multidataset task, loop over each dataset
# and enable gradients for the inputs that need them
for dset_name, data in batch.items():
Expand Down

0 comments on commit 6e54dad

Please sign in to comment.