Skip to content

Commit

Permalink
WIP: testing how to integrate plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
agoscinski committed Oct 16, 2024
1 parent f2ad28e commit 6618dbb
Show file tree
Hide file tree
Showing 5 changed files with 470 additions and 71 deletions.
67 changes: 50 additions & 17 deletions src/wcflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

from datetime import datetime
from itertools import chain
from os.path import expandvars
from pathlib import Path
from typing import TYPE_CHECKING

from isoduration import parse_duration
from isoduration.types import Duration

from wcflow.parsing._utils import TimeUtils
from wcflow._utils import TimeUtils

if TYPE_CHECKING:
from collections.abc import Generator
Expand All @@ -24,20 +23,25 @@ def __init__(
lag: list[Duration],
date: list[datetime],
arg_option: str | None,
port_name: str | None,
):
self._name = name

self._src = src
self._path = Path(expandvars(self._src))

self._type = type
if self._type not in ["file", "dir"]:
msg = f"Data type {self._type!r} not supported. Please use 'file' or 'dir'."
raise ValueError(msg)

if len(lag) > 0 and len(date) > 0:
msg = "Either 'lag' or 'date' can be nonempty. Not both."
raise ValueError(msg)
# type probably not needed
#if self._type not in ["file", "dir", ""]:
# msg = f"Data type {self._type!r} not supported. Please use 'file' or 'dir'."
# raise ValueError(msg)

# check if not already done by parsing? best inherit these
#if len(lag) > 0 and len(date) > 0:
# msg = "Either 'lag' or 'date' can be nonempty. Not both."
# raise ValueError(msg)

self._port_name = self._name if port_name is None else port_name

# COMMENT I think we should just disallow multiple lags, and enforce the user to write multiple lags
# I am not sure how this work with icon as it does not need positional arguments
Expand Down Expand Up @@ -83,6 +87,10 @@ def date(self) -> list[datetime]:
def arg_option(self) -> str | None:
return self._arg_option

@property
def port_name(self) -> str:
return self._port_name


class Data(_DataBase):
def __init__(
Expand All @@ -93,8 +101,9 @@ def __init__(
lag: list[Duration],
date: list[datetime],
arg_option: str | None,
port_name: str | None
):
super().__init__(name, type, src, lag, date, arg_option)
super().__init__(name, type, src, lag, date, arg_option, port_name)
self._task: Task | None = None

def unroll(self, unrolled_task: UnrolledTask) -> Generator[UnrolledData, None, None]:
Expand Down Expand Up @@ -130,7 +139,6 @@ def task(self, task: Task):
raise ValueError(msg)
self._task = task


class UnrolledData(_DataBase):
"""
Data that are created during the unrolling of a cycle.
Expand All @@ -139,7 +147,7 @@ class UnrolledData(_DataBase):

@classmethod
def from_data(cls, data: Data, unrolled_task: UnrolledTask, unrolled_date: datetime):
return cls(unrolled_task, unrolled_date, data.name, data.type, data.src, data.lag, data.date, data.arg_option)
return cls(unrolled_task, unrolled_date, data.name, data.type, data.src, data.lag, data.date, data.arg_option, data.port_name)

def __init__(
self,
Expand All @@ -151,8 +159,9 @@ def __init__(
lag: list[Duration],
date: list[datetime],
arg_option: str | None,
port_name: str
):
super().__init__(name, type, src, lag, date, arg_option)
super().__init__(name, type, src, lag, date, arg_option, port_name)
self._unrolled_task = unrolled_task
self._unrolled_date = unrolled_date

Expand Down Expand Up @@ -323,13 +332,19 @@ def __init__(
outputs: list[Data],
depends: list[Dependency],
command_option: str | None,
plugin: str,
code: str,
computer: str,
):
self._name = name
self._command = expandvars(command)
self._command = command
self._inputs = inputs
self._outputs = outputs
self._depends = depends
self._command_option = command_option
self._plugin = plugin
self._code = code
self._computer = computer

@property
def name(self) -> str:
Expand All @@ -355,6 +370,18 @@ def command_option(self) -> str | None:
def depends(self) -> list[Dependency]:
return self._depends

@property
def plugin(self):
return self._plugin

@property
def code(self):
return self._code

@property
def computer(self):
return self._computer


class Task(_TaskBase):
"""A task that is created during the unrolling of a cycle."""
Expand All @@ -367,8 +394,11 @@ def __init__(
outputs: list[Data],
depends: list[Dependency],
command_option: str | None,
plugin: str,
code: str,
computer: str,
):
super().__init__(name, command, inputs, outputs, depends, command_option)
super().__init__(name, command, inputs, outputs, depends, command_option, plugin, code, computer)
for input_ in inputs:
input_.task = self
for output in outputs:
Expand Down Expand Up @@ -409,7 +439,7 @@ class UnrolledTask(_TaskBase):
@classmethod
def from_task(cls, task: Task, unrolled_cycle: UnrolledCycle):
return cls(
unrolled_cycle, task.name, task.command, task.inputs, task.outputs, task.depends, task.command_option
unrolled_cycle, task.name, task.command, task.inputs, task.outputs, task.depends, task.command_option, task.plugin, task.code, task.computer
)

def __init__(
Expand All @@ -421,8 +451,11 @@ def __init__(
outputs: list[Data],
depends: list[Dependency],
command_option: str | None,
plugin: str,
code: str,
computer: str,
):
super().__init__(name, command, inputs, outputs, depends, command_option)
super().__init__(name, command, inputs, outputs, depends, command_option, plugin, code, computer)
self._unrolled_cycle = unrolled_cycle
self._unrolled_inputs = list(self.unroll_inputs())
self._unrolled_outputs = list(self.unroll_outputs())
Expand Down
61 changes: 35 additions & 26 deletions src/wcflow/parsing/_yaml_data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,17 +83,14 @@ class ConfigTask(_NamedBaseModel):
To create an instance of a task defined in a workflow file
"""

command: str
command_option: str | None = None
host: str | None = None
account: str | None = None
# aiida-shell specifics
command: str | None = None
# these are global commands_options, I think we can remove arg_options to simplify
command_options: str | None = None
# general
plugin: str | None = None
config: str | None = None
uenv: dict | None = None
nodes: int | None = None
walltime: str | None = None
src: str | None = None
conda_env: str | None = None
computer: str | None = None
code: str | None = None

def __init__(self, /, **data):
# We have to treat root special as it does not typically define a command
Expand All @@ -107,31 +104,37 @@ def expand_env_vars(cls, value: str) -> str:
"""Expands any environment variables in the value"""
return expandvars(value)

@field_validator("walltime")
@classmethod
def convert_to_struct_time(cls, value: str | None) -> time.struct_time | None:
"""Converts a string of form "%H:%M:%S" to a time.time_struct"""
return None if value is None else time.strptime(value, "%H:%M:%S")


class ConfigData(_NamedBaseModel):
"""
To create an instance of a data defined in a workflow file.
"""

type: str
src: str
type: str | None = None
src: str | int | dict | None = None
format: str | None = None

@field_validator("type")
@classmethod
def is_file_or_dir(cls, value: str) -> str:
"""."""
if value not in ["file", "dir"]:
msg = "Must be one of 'file' or 'dir'."
raise ValueError(msg)
# Is this actually needed? We can refer everything from the plugin
#if value not in ["file", "dir", "int"]:
# msg = "Must be one of 'file' or 'dir'."
# raise ValueError(msg)
return value

@field_validator("src")
@classmethod
def expand_env_vars(cls, value: str | int | dict | None) -> str | int | dict | None:
"""Expands any environment variables in the value"""
if isinstance(value, str):
return expandvars(value)
elif isinstance(value, dict):
raise NotImplementedError()
else:
return value


class ConfigCycleTaskDepend(_NamedBaseModel, _LagDateBaseModel):
"""
Expand All @@ -142,6 +145,7 @@ class ConfigCycleTaskDepend(_NamedBaseModel, _LagDateBaseModel):
cycle_name: str | None = None


# NOT NEEDED, but can we make dict with arbitrary
class ConfigCycleTaskInput(_NamedBaseModel, _LagDateBaseModel):
"""
To create an instance of an input in a task in a cycle defined in a workflow file.
Expand All @@ -156,13 +160,15 @@ class ConfigCycleTaskInput(_NamedBaseModel, _LagDateBaseModel):
"""

arg_option: str | None = None
# lag?
port_name: str | None = None


class ConfigCycleTaskOutput(_NamedBaseModel):
"""
To create an instance of an output in a task in a cycle defined in a workflow file.
"""

port_name: str | None = None

class ConfigCycleTask(_NamedBaseModel):
"""
Expand All @@ -175,7 +181,7 @@ class ConfigCycleTask(_NamedBaseModel):

@field_validator("inputs", mode="before")
@classmethod
def convert_cycle_task_inputs(cls, values) -> list[ConfigCycleTaskInput]:
def convert_cycle_task_inputs(cls, values) -> list[dict]:
inputs = []
if values is None:
return inputs
Expand Down Expand Up @@ -295,14 +301,14 @@ def _to_core_task(self, cycle_task: ConfigCycleTask) -> core.Task:
if (data := self.data_dict.get(input_.name)) is None:
msg = f"Task {cycle_task.name!r} has input {input_.name!r} that is not specied in the data section."
raise ValueError(msg)
core_data = core.Data(input_.name, data.type, data.src, input_.lag, input_.date, input_.arg_option)
core_data = core.Data(input_.name, data.type, data.src, input_.lag, input_.date, input_.arg_option, input_.port_name)
inputs.append(core_data)

for output in cycle_task.outputs:
if (data := self.data_dict.get(output.name)) is None:
msg = f"Task {cycle_task.name!r} has output {output.name!r} that is not specied in the data section."
raise ValueError(msg)
core_data = core.Data(output.name, data.type, data.src, [], [], None)
core_data = core.Data(output.name, data.type, data.src, [], [], None, output.port_name)
outputs.append(core_data)

for depend in cycle_task.depends:
Expand All @@ -315,7 +321,10 @@ def _to_core_task(self, cycle_task: ConfigCycleTask) -> core.Task:
inputs,
outputs,
dependencies,
self.task_dict[cycle_task.name].command_option,
self.task_dict[cycle_task.name].command_options,
self.task_dict[cycle_task.name].plugin,
self.task_dict[cycle_task.name].code,
self.task_dict[cycle_task.name].computer,
)


Expand Down
Loading

0 comments on commit 6618dbb

Please sign in to comment.