Skip to content

Commit

Permalink
checkpoint for the changes on yml
Browse files Browse the repository at this point in the history
  • Loading branch information
agoscinski committed Dec 9, 2024
1 parent ba9f9ef commit af7e4e4
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 23 deletions.
4 changes: 2 additions & 2 deletions src/sirocco/core/_tasks/shell_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass
from dataclasses import field, dataclass
from typing import ClassVar

from sirocco.core.graph_items import Task
Expand All @@ -12,5 +12,5 @@ class ShellTask(Task):

command: str | None = None
command_option: str | None = None
input_arg_options: dict[str, str] | None = None
input_arg_options: dict[str, str] = field(default_factory=dict)
src: str | None = None
7 changes: 7 additions & 0 deletions src/sirocco/core/graph_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from dataclasses import dataclass, field
from itertools import chain, product
from os.path import expandvars
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Self

if TYPE_CHECKING:
Expand Down Expand Up @@ -116,6 +118,11 @@ def from_config(cls, config: DataBaseModel, coordinates: dict) -> Self:
coordinates=coordinates,
)

@property
def path(self) -> Path:
# TODO yaml level?
return Path(expandvars(self.src))


@dataclass(kw_only=True)
class Cycle(GraphItem):
Expand Down
52 changes: 40 additions & 12 deletions src/sirocco/parsing/_yaml_data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,30 @@ class _NamedBaseModel(BaseModel):
name: str

def __init__(self, /, **data):
super().__init__(**self.merge_name_and_specs(data))

@staticmethod
def merge_name_and_specs(data: dict) -> dict:
"""
Converts dict of form
`{my_name: {'spec_0': ..., ..., 'spec_n': ...}`
to
`{'name': my_name, 'spec_0': ..., ..., 'spec_n': ...}`
by copy.
"""
name_and_spec = {}
# - my_name:
# ...
if len(data) != 1:
msg = f"Expected dict with one element of the form {{'name': specification}} but got {data}."
raise ValueError(msg)
name_and_spec["name"] = next(iter(data.keys()))
# if no specification specified e.g. "- my_name:"
if (spec := next(iter(data.values()))) is not None:
name_and_spec.update(spec)

super().__init__(**name_and_spec)
return name_and_spec


class _WhenBaseModel(BaseModel):
Expand Down Expand Up @@ -228,14 +240,11 @@ def check_period_is_not_negative_or_zero(self) -> ConfigCycle:
raise ValueError(msg)
return self

from typing import Literal

class ConfigTask(_NamedBaseModel):
"""
To create an instance of a task defined in a workflow file
"""

class ConfigTaskBase(_NamedBaseModel):
# config for genric task, no plugin specifics
parameters: list[str] = []
parameters: list[str] = Field(default_factory=list)
host: str | None = None
account: str | None = None
plugin: str | None = None
Expand All @@ -256,7 +265,17 @@ def convert_to_struct_time(cls, value: str | None) -> time.struct_time | None:
return None if value is None else time.strptime(value, "%H:%M:%S")


# TODO(maybe): ConfigTaskIcon(ConfigTask) and ConfigTaskShell(ConfigTask)

class ConfigTaskShell(ConfigTaskBase):
plugin: Literal["shell"]
command: str
command_option: str = ""
input_arg_options: dict[str, str] = Field(default_factory=dict)
src: str | None = None


class ConfigTaskIcon(ConfigTaskBase):
plugin: Literal["icon"]


class DataBaseModel(_NamedBaseModel):
Expand Down Expand Up @@ -297,11 +316,20 @@ class ConfigData(BaseModel):
available: list[ConfigAvailableData] = []
generated: list[ConfigGeneratedData] = []

from typing import Annotated
from pydantic import Discriminator, Tag

def get_plugin_from_named_base_model(data: dict) -> str:
plugin =_NamedBaseModel.merge_name_and_specs(data).get("plugin", "")
return plugin

class ConfigWorkflow(BaseModel):
name: str | None = None
cycles: list[ConfigCycle]
tasks: list[ConfigTask]
tasks: list[Annotated[
Annotated[ConfigTaskIcon, Tag("icon")] |
Annotated[ConfigTaskShell, Tag("shell")],
Discriminator(get_plugin_from_named_base_model)]]
data: ConfigData
parameters: dict[str, list] = {}
data_dict: dict = {}
Expand Down
16 changes: 9 additions & 7 deletions src/sirocco/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(self, core_workflow: core.Workflow):

def _validate_workflow(self):
"""Checks if the defined workflow is correctly referencing key names."""
for cycle in self._core_workflow.cycles.values():
for cycle in self._core_workflow.cycles:
try:
aiida.common.validate_link_label(cycle.name)
except ValueError as exception:
Expand Down Expand Up @@ -109,7 +109,7 @@ def _add_aiida_initial_data_nodes(self):
"""
Nodes that correspond to data that are available on initialization of the workflow
"""
for cycle in self._core_workflow.cycles.values():
for cycle in self._core_workflow.cycles:
for task in cycle.tasks:
for input_ in task.inputs:
if input_.available:
Expand All @@ -126,7 +126,7 @@ def parse_to_aiida_label(label: str) -> str:
def get_aiida_label_from_unrolled_data(obj: core.BaseNode) -> str:
""" """
return AiidaWorkGraph.parse_to_aiida_label(
f"{obj.name}" + "_".join(f"_{key}_{value}" for key, value in obj.parameters.items())
f"{obj.name}" + "_".join(f"_{key}_{value}" for key, value in obj.coordinates.items())
)

@staticmethod
Expand All @@ -137,7 +137,7 @@ def get_aiida_label_from_unrolled_task(obj: core.BaseNode) -> str:
# Otherwise the label is not unique
# --> task name + date + parameters
return AiidaWorkGraph.parse_to_aiida_label(
f"{obj.name}" + "_".join(f"_{key}_{value}" for key, value in obj.parameters.items())
f"{obj.name}" + "_".join(f"_{key}_{value}" for key, value in obj.coordinates.items())
)

def _add_aiida_input_data_node(self, input_: core.UnrolledData):
Expand All @@ -156,17 +156,19 @@ def _add_aiida_input_data_node(self, input_: core.UnrolledData):
raise ValueError(msg)

def _add_aiida_task_nodes(self):
for cycle in self._core_workflow.cycles.values():
for cycle in self._core_workflow.cycles:
for task in cycle.tasks:
self._add_aiida_task_node(task)
# after creation we can link the wait_on tasks
# TODO check where this is now
#for cycle in self._core_workflow.cycles.values():
#for cycle in self._core_workflow.cycles:
# for task in cycle.tasks:
# self._link_wait_on_to_task(task)

def _add_aiida_task_node(self, task: core.UnrolledTask):
label = AiidaWorkGraph.get_aiida_label_from_unrolled_task(task)
if task.command is None:
raise ValueError(f"The command is None of task {task}.")
workgraph_task = self._workgraph.tasks.new(
"ShellJob",
name=label,
Expand All @@ -189,7 +191,7 @@ def _link_wait_on_to_task(self, task: core.UnrolledTask):
workgraph_task.wait = wait_on_tasks

def _add_aiida_links(self):
for cycle in self._core_workflow.cycles.values():
for cycle in self._core_workflow.cycles:
self._add_aiida_links_from_cycle(cycle)

def _add_aiida_links_from_cycle(self, cycle: core.UnrolledCycle):
Expand Down
4 changes: 2 additions & 2 deletions tests/files/configs/test_config_parameters.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,5 @@ data:
src: yearly_analysis

parameters:
foo: [0, 1]
bar: {arange: [3, 4]}
foo: [0, 1, 2]
bar: [3.0, 3.5]

0 comments on commit af7e4e4

Please sign in to comment.