Skip to content

Commit

Permalink
First version to make to_core_* methods work
Browse files Browse the repository at this point in the history
Starting from the pydantic config representations. Ideally, I think one
should rather have `from_config_*` methods in `core.py`, rather than
`to_core_*` methods in `_yaml_data_models.py`, but chose this route as
it's simpler to adapt for now.
  • Loading branch information
GeigerJ2 committed Oct 15, 2024
1 parent cfe036a commit 9daed69
Showing 1 changed file with 88 additions and 12 deletions.
100 changes: 88 additions & 12 deletions src/wcflow/parsing/_yaml_data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from datetime import datetime
from os.path import expandvars
from pathlib import Path
from typing import Any
from typing import Any, List, Optional

from isoduration import parse_duration
from isoduration.types import Duration # pydantic needs type # noqa: TCH002
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, field_validator, model_validator, Field
from wcflow import core

from ._utils import TimeUtils

from rich.pretty import pprint

class _NamedBaseModel(BaseModel):
"""Base class for all classes with a key that specifies their name.
Expand Down Expand Up @@ -167,9 +168,9 @@ class ConfigCycleTask(_NamedBaseModel):
To create an instance of a task in a cycle defined in a workflow file.
"""

inputs: list[ConfigCycleTaskInput | str] | None = None
outputs: list[ConfigCycleTaskOutput | str] | None = None
depends: list[ConfigCycleTaskDepend | str] | None = None
inputs: Optional[List[ConfigCycleTaskInput | str]] = Field(default_factory=list)
outputs: Optional[List[ConfigCycleTaskOutput | str]] = Field(default_factory=list)
depends: Optional[List[ConfigCycleTaskDepend | str]] = Field(default_factory=list)

@field_validator("inputs", mode="before")
def convert_cycle_task_inputs(values) -> list[ConfigCycleTaskInput]:
Expand All @@ -195,6 +196,22 @@ def convert_cycle_task_outputs(values) -> list[ConfigCycleTaskOutput]:
outputs.append(value)
return outputs

@field_validator("depends", mode="before")
def convert_cycle_task_depends(values) -> list[ConfigCycleTaskDepend]:
depends = []
print(f'VALUES: {values}')
if values is None:
return depends
for value in values:
if isinstance(value, str):
depends.append({value: None})
elif isinstance(value, dict):
depends.append(value)

return depends



class ConfigCycle(_NamedBaseModel):
"""
To create an instance of a cycle defined in a workflow file.
Expand All @@ -218,15 +235,21 @@ def convert_duration(cls, value):
return None if value is None else parse_duration(value)

@model_validator(mode="after")
def check_start_date_before_end_date(self) -> 'ConfigCycle':
if self.start_date is not None and self.end_date is not None and self.start_date > self.end_date:
def check_start_date_before_end_date(self) -> "ConfigCycle":
if (
self.start_date is not None
and self.end_date is not None
and self.start_date > self.end_date
):
msg = "For cycle {self._name!r} the start_date {start_date!r} lies after given end_date {end_date!r}."
raise ValueError(msg)
return self

@model_validator(mode="after")
def check_period_is_not_negative_or_zero(self) -> 'ConfigCycle':
if self.period is not None and TimeUtils.duration_is_less_equal_zero(self.period):
def check_period_is_not_negative_or_zero(self) -> "ConfigCycle":
if self.period is not None and TimeUtils.duration_is_less_equal_zero(
self.period
):
msg = f"For cycle {self.name!r} the period {self.period!r} is negative or zero."
raise ValueError(msg)
return self
Expand All @@ -239,19 +262,72 @@ class ConfigWorkflow(BaseModel):
cycles: list[ConfigCycle]
tasks: list[ConfigTask]
data: list[ConfigData]
data_dict: dict = {}
task_dict: dict = {}

@field_validator("start_date", "end_date", mode="before")
@classmethod
def convert_datetime(cls, value) -> None | datetime:
return None if value is None else datetime.fromisoformat(value)

@model_validator(mode="after")
def check_start_date_before_end_date(self) -> 'ConfigWorkflow':
if self.start_date is not None and self.end_date is not None and self.start_date > self.end_date:
def check_start_date_before_end_date(self) -> "ConfigWorkflow":
if (
self.start_date is not None
and self.end_date is not None
and self.start_date > self.end_date
):
msg = "For workflow {self._name!r} the start_date {start_date!r} lies after given end_date {end_date!r}."
raise ValueError(msg)
return self

def to_core_workflow(self):

self.data_dict = {data.name: data for data in self.data}
self.task_dict = {task.name: task for task in self.tasks}

core_cycles = [self._to_core_cycle(cycle) for cycle in self.cycles]
return core.Workflow(self.name, core_cycles)

def _to_core_cycle(self, cycle: ConfigCycle) -> core.Cycle:
core_tasks = [self._to_core_task(task) for task in cycle.tasks]
start_date = self.start_date if cycle.start_date is None else cycle.start_date
end_date = self.end_date if cycle.end_date is None else cycle.end_date
return core.Cycle(cycle.name, core_tasks, start_date, end_date, cycle.period)

def _to_core_task(self, cycle_task: ConfigCycleTask) -> core.Task:

inputs = []
outputs = []
dependencies = []

for input_ in cycle_task.inputs:
if (data := self.data_dict.get(input_.name)) is None:
msg = f"Task {cycle_task.name!r} has input {input_.name!r} that is not specied in the data section."
raise ValueError(msg)
core_data = core.Data(input_.name, data.type, data.src, input_.lag, input_.date, input_.arg_option)
inputs.append(core_data)

for output in cycle_task.outputs:
if (data := self.data_dict.get(output.name)) is None:
msg = f"Task {cycle_task.name!r} has output {output.name!r} that is not specied in the data section."
raise ValueError(msg)
core_data = core.Data(output.name, data.type, data.src, [], [], None)
outputs.append(core_data)

for depend in cycle_task.depends:
core_dependency = core.Dependency(depend.name, depend.lag, depend.date, depend.cycle_name)
dependencies.append(core_dependency)

return core.Task(
cycle_task.name,
self.task_dict[cycle_task.name].command,
inputs,
outputs,
dependencies,
self.task_dict[cycle_task.name].command_option,
)


def load_workflow_config(workflow_config: str) -> ConfigWorkflow:
"""
Expand Down

0 comments on commit 9daed69

Please sign in to comment.