Skip to content

Commit

Permalink
refactor: allowing task to be multi task
Browse files Browse the repository at this point in the history
  • Loading branch information
laserkelvin committed May 15, 2024
1 parent e8fdea8 commit affa5b4
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions matsciml/interfaces/ase/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ScalarRegressionTask,
GradFreeForceRegressionTask,
ForceRegressionTask,
MultiTaskLitModule,
)
from matsciml.datasets.transforms.base import AbstractDataTransform

Expand Down Expand Up @@ -74,7 +75,8 @@ def __init__(
self,
task_module: ScalarRegressionTask
| GradFreeForceRegressionTask
| ForceRegressionTask,
| ForceRegressionTask
| MultiTaskLitModule,
transforms: list[AbstractDataTransform | Callable] | None = None,
restart=None,
label=None,
Expand All @@ -87,7 +89,12 @@ def __init__(
)
assert isinstance(
task_module,
(ForceRegressionTask, ScalarRegressionTask, GradFreeForceRegressionTask),
(
ForceRegressionTask,
ScalarRegressionTask,
GradFreeForceRegressionTask,
MultiTaskLitModule,
),
), f"Expected task to be one that is capable of energy/force prediction. Got {task_module.__type__}."
self.task_module = task_module
self.transforms = transforms
Expand Down

0 comments on commit affa5b4

Please sign in to comment.