diff --git a/matsciml/interfaces/ase/base.py b/matsciml/interfaces/ase/base.py index 7c4838a1..9913c769 100644 --- a/matsciml/interfaces/ase/base.py +++ b/matsciml/interfaces/ase/base.py @@ -13,6 +13,7 @@ ScalarRegressionTask, GradFreeForceRegressionTask, ForceRegressionTask, + MultiTaskLitModule, ) from matsciml.datasets.transforms.base import AbstractDataTransform @@ -74,7 +75,8 @@ def __init__( self, task_module: ScalarRegressionTask | GradFreeForceRegressionTask - | ForceRegressionTask, + | ForceRegressionTask + | MultiTaskLitModule, transforms: list[AbstractDataTransform | Callable] | None = None, restart=None, label=None, @@ -87,7 +89,12 @@ def __init__( ) assert isinstance( task_module, - (ForceRegressionTask, ScalarRegressionTask, GradFreeForceRegressionTask), + ( + ForceRegressionTask, + ScalarRegressionTask, + GradFreeForceRegressionTask, + MultiTaskLitModule, + ), ), f"Expected task to be one that is capable of energy/force prediction. Got {task_module.__type__}." self.task_module = task_module self.transforms = transforms