diff --git a/matsciml/models/base.py b/matsciml/models/base.py index ff69986c..54e86edd 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -2357,7 +2357,9 @@ def _toggle_input_grads( """ need_grad_keys = getattr(self, "input_grad_keys", None) if need_grad_keys is not None: - if self.is_multidata: + # we determine if it's multidata based on the incoming batch + # as it should have dataset in its key + if any(["Dataset" in key for key in batch.keys()]): # if this is a multidataset task, loop over each dataset # and enable gradients for the inputs that need them for dset_name, data in batch.items():