Skip to content

Commit

Permalink
Added more testing and comments
Browse files Browse the repository at this point in the history
  • Loading branch information
JosePizarro3 committed Oct 2, 2024
1 parent a3b95bf commit d2d57ff
Show file tree
Hide file tree
Showing 3 changed files with 349 additions and 19 deletions.
38 changes: 23 additions & 15 deletions src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,15 @@
from nomad.datamodel.datamodel import EntryArchive
from structlog.stdlib import BoundLogger

from nomad_simulations.schema_packages.workflow import SinglePoint

from nomad.datamodel.metainfo.workflow import Link, TaskReference
from nomad.metainfo import Quantity, Reference

from nomad_simulations.schema_packages.model_method import DFT, TB, ModelMethod
from nomad_simulations.schema_packages.workflow import (
BeyondDFT,
BeyondDFTMethod,
)
from nomad_simulations.schema_packages.model_method import DFT, TB
from nomad_simulations.schema_packages.workflow import BeyondDFT, BeyondDFTMethod
from nomad_simulations.schema_packages.workflow.base_workflows import check_n_tasks

from .single_point import SinglePoint


class DFTPlusTBMethod(BeyondDFTMethod):
"""
Expand Down Expand Up @@ -106,7 +103,15 @@ class DFTPlusTB(BeyondDFT):
"""

@check_n_tasks(n_tasks=2)
def link_task_inputs_outputs(self, tasks: list[TaskReference]) -> None:
def link_task_inputs_outputs(
self, tasks: list[TaskReference], logger: 'BoundLogger'
) -> None:
if not self.inputs or not self.outputs:
logger.warning(
'The `DFTPlusTB` workflow needs to have `inputs` and `outputs` defined in order to link with the `tasks`.'
)
return None

dft_task = tasks[0]
tb_task = tasks[1]

Expand Down Expand Up @@ -144,7 +149,7 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
'A `DFTPlusTB` workflow must have two `SinglePoint` tasks references.'
)
return
if not isinstance(task.task, 'SinglePoint'):
if not isinstance(task.task, SinglePoint):
logger.error(
'The referenced tasks in the `DFTPlusTB` workflow must be of type `SinglePoint`.'
)
Expand All @@ -158,11 +163,14 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
tasks=self.tasks,
tasks_names=['DFT SinglePoint Task', 'TB SinglePoint Task'],
)
if method_refs is not None and len(method_refs) == 2:
self.method = DFTPlusTBMethod(
dft_method_ref=method_refs[0],
tb_method_ref=method_refs[1],
)
if method_refs is not None:
method_workflow = DFTPlusTBMethod()
for method in method_refs:
if isinstance(method, DFT):
method_workflow.dft_method_ref = method
elif isinstance(method, TB):
method_workflow.tb_method_ref = method
self.method = method_workflow

# Resolve `tasks[*].inputs` and `tasks[*].outputs`
self.link_task_inputs_outputs(tasks=self.tasks)
self.link_task_inputs_outputs(tasks=self.tasks, logger=logger)
6 changes: 6 additions & 0 deletions tests/workflow/test_base_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,16 @@ def test_resolve_all_outputs(self, tasks: list[Task], result: list[Outputs]):
[
# no task
(None, None),
# only one task
([TaskReference()], []),
# two empty tasks
([TaskReference(), TaskReference()], []),
# two tasks with only empty task
(
[TaskReference(task=SinglePoint()), TaskReference(task=SinglePoint())],
[],
),
# two tasks with task with one input ModelSystem each
(
[
TaskReference(
Expand All @@ -101,6 +105,7 @@ def test_resolve_all_outputs(self, tasks: list[Task], result: list[Outputs]):
],
[],
),
# two tasks with task with one input ModelSystem each and only DFT input
(
[
TaskReference(
Expand All @@ -121,6 +126,7 @@ def test_resolve_all_outputs(self, tasks: list[Task], result: list[Outputs]):
],
[DFT],
),
# two tasks with task with inputs for ModelSystem and DFT and TB
(
[
TaskReference(
Expand Down
Loading

0 comments on commit d2d57ff

Please sign in to comment.