Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to support executing EPP configurations #56

Merged
merged 11 commits (source and target branch names not captured in this page extraction)
Oct 3, 2024
2 changes: 1 addition & 1 deletion dagrunner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
# See LICENSE in the root of the repository for full licensing details.
from .plugin_framework import DataPolling, Input, NodeAwarePlugin, Plugin, Shell

__all__ = ["Plugin", "Shell", "Input", "DataPolling", "NodeAwarePlugin"]
__all__ = ["Plugin", "Shell", "Input", "DataPolling", "NodeAwarePlugin", "Load"]

__version__ = "0.0.1dev"
68 changes: 37 additions & 31 deletions dagrunner/execute_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from dagrunner.utils import (
CaptureProcMemory,
TimeIt,
function_to_argparse,
function_to_argparse_parse_args,
logger,
)
from dagrunner.utils.visualisation import visualise_graph
Expand Down Expand Up @@ -139,19 +139,7 @@ def plugin_executor(
print(f"Skipping node {call[0]}")
return SKIP_EVENT

# Handle call tuple unpacking (length 2, no class init kwargs
# or length 3 with class init kwargs).
try:
callable_obj, callable_kwargs_init, callable_kwargs = call
except ValueError as e:
if (
str(e) == "not enough values to unpack (expected 3, got 2)"
): # no class init kwargs
callable_obj, callable_kwargs = call
callable_kwargs_init = {}
else:
raise e

callable_obj = call[0]
if isinstance(callable_obj, str):
# import callable if a string is provided
module_name, function_name = callable_obj.rsplit(".", 1)
Expand All @@ -160,6 +148,37 @@ def plugin_executor(
print(f"imported module '{module}', callable '{function_name}'")
callable_obj = getattr(module, function_name)

# Handle call tuple unpacking (length 2, no class init kwargs
# or length 3 with class init kwargs).
if isinstance(callable_obj, type):
if len(call) == 3:
_, callable_kwargs_init, callable_kwargs = call
elif len(call) == 2:
_, callable_kwargs_init = call
callable_kwargs = {}
elif len(call) == 1:
callable_kwargs = {}
callable_kwargs_init = {}
else:
raise ValueError(
f"expecting 1, 2 or 3 values to unpack for {callable_obj}, "
f"got {len(call)}"
)
callable_kwargs_init = (
{} if callable_kwargs_init is None else callable_kwargs_init
)
else:
if len(call) == 2:
_, callable_kwargs = call
elif len(call) == 1:
callable_kwargs = {}
else:
raise ValueError(
f"expecting 1 or 2 values to unpack for {callable_obj}, got "
f"{len(call)}"
)
callable_kwargs = {} if callable_kwargs is None else callable_kwargs

call_msg = ""
obj_name = callable_obj.__name__
if isinstance(callable_obj, type):
Expand Down Expand Up @@ -230,8 +249,6 @@ def _get_networkx(networkx_graph):
module = importlib.import_module(".".join(parts[:-1]))
networkx_graph = parts[-1]
nxgraph = getattr(module, networkx_graph)
elif callable(networkx_graph):
nxgraph = networkx_graph()
else:
try:
edges, nodes = networkx_graph
Expand All @@ -257,7 +274,6 @@ def __init__(
profiler_filepath: str = None,
dry_run: bool = False,
verbose: bool = False,
sqlite_filepath: str = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -288,8 +304,6 @@ def __init__(
Optional.
- `verbose` (bool):
Print executed commands. Optional.
- `sqlite_filepath` (str):
Filepath to a SQLite database to store log records. Optional.
- `**kwargs`:
Optional global keyword arguments to apply to all applicable plugins.
"""
Expand All @@ -306,7 +320,6 @@ def __init__(
self._profiler_output = profiler_filepath
self._kwargs = kwargs | {"verbose": verbose, "dry_run": dry_run}
self._exec_graph = self._process_graph()
self._sqlite_filepath = sqlite_filepath

@property
def nxgraph(self):
Expand Down Expand Up @@ -348,9 +361,7 @@ def visualise(self, output_filepath: str):
_attempt_visualise_graph(self._exec_graph, output_filepath)

def __call__(self):
with logger.ServerContext(sqlite_filepath=self._sqlite_filepath), TimeIt(
verbose=True
), self._scheduler(
with TimeIt(verbose=True), self._scheduler(
self._num_workers, profiler_filepath=self._profiler_output
) as scheduler:
try:
Expand All @@ -365,14 +376,9 @@ def main():
Entry point of the program.
Parses command line arguments and executes the graph using the ExecuteGraph class.
"""
parser = function_to_argparse(ExecuteGraph, exclude=["plugin_executor"])
args = parser.parse_args()
args = vars(args)
# positional arguments with '-' aren't converted to '_' by argparse.
args = {key.replace("-", "_"): value for key, value in args.items()}
if args.get("verbose", False):
print(f"CLI call arguments: {args}")
kwargs = args.pop("kwargs", None) or {}
args, kwargs = function_to_argparse_parse_args(
ExecuteGraph, exclude=["plugin_executor"]
)
ExecuteGraph(**args, **kwargs)()


Expand Down
110 changes: 69 additions & 41 deletions dagrunner/plugin_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This file is part of 'dagrunner' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.
import itertools
import json
import os
import pickle
Expand Down Expand Up @@ -179,61 +180,88 @@ def __call__(
- timeout (int): Timeout in seconds (default is 120 seconds).
- polling (int): Time interval in seconds between each poll (default is 1
second).
- file_count (int): Expected number of files to be found (default is None).
If specified, the total number of files found can be greater than the
number of arguments. Each argument is expected to return a minimum of
1 match each in either case.
- file_count (int): Expected number of files to be found for globular
expansion (default is >= 1 files per pattern).

Returns:
- None

Raises:
- RuntimeError: If the timeout is reached before all files are found.
"""
time_taken = indx = patterns_found = files_found = 0
fpaths_found = []
file_count = len(args) if file_count is None else max(file_count, len(args))

# Define a key function
def host_and_glob_key(path):
psplit = path.split(":")
host = psplit[0] if ":" in path else "" # Extract host if available
is_glob = psplit[-1] if "*" in psplit[-1] else "" # Glob pattern
return (host, is_glob)

time_taken = 0
fpaths_found = set()
args = list(map(process_path, args))
while time_taken < timeout:
pattern = args[indx]
host = None
if ":" in pattern:
host, pattern = pattern.split(":")

if host:
# bash equivalent to python glob (glob on remote host)
expanded_paths = subprocess.run(
f'ssh {host} \'for file in {pattern}; do if [ -e "$file" ]; then '
'echo "$file"; fi; done\'',
shell=True,
check=True,
text=True,
capture_output=True,
).stdout.strip()
# Group by host and whether it's a glob pattern
sorted_args = sorted(args, key=host_and_glob_key)
args_by_host = [
[key, set(map(lambda path: path.split(":")[-1], group))]
for key, group in itertools.groupby(sorted_args, key=host_and_glob_key)
]

for ind, ((host, globular), paths) in enumerate(args_by_host):
globular = bool(globular)
host_msg = f"{host}:" if host else ""
while time_taken < timeout or not timeout:
if host:
# bash equivalent to python glob (glob on remote host)
expanded_paths = subprocess.run(
f'ssh {host} \'for file in {" ".join(paths)}; do if '
'[ -e "$file" ]; then echo "$file"; fi; done\'',
shell=True,
check=True,
text=True,
capture_output=True,
).stdout.strip()
if expanded_paths:
expanded_paths = expanded_paths.split("\n")
else:
expanded_paths = list(
itertools.chain.from_iterable(map(glob, paths))
)
if expanded_paths:
expanded_paths = expanded_paths.split("\n")
else:
expanded_paths = glob(pattern)
if expanded_paths:
fpaths_found.extend(expanded_paths)
patterns_found += 1
files_found += len(expanded_paths)
indx += 1
elif verbose:
print(
f"polling for '{pattern}', time taken: {time_taken}s of limit "
f"{timeout}s"
fpaths_found = fpaths_found.union(expanded_paths)
if globular and (
not file_count or len(expanded_paths) >= file_count
):
# globular expansion completed
paths = set()
else:
# Remove paths we have found
paths = paths - set(expanded_paths)

if paths:
if timeout:
print(
f"polling for {host_msg}{paths}, time taken: "
f"{time_taken}s of limit {timeout}s"
)
time.sleep(polling)
time_taken += polling
else:
break
else:
break

if paths:
raise FileNotFoundError(
f"Timeout waiting for: {host_msg}{'; '.join(sorted(paths))}"
)
if patterns_found >= len(args) or files_found >= file_count:
break
time.sleep(polling)
time_taken += polling

if patterns_found < len(args):
raise FileNotFoundError(f"Timeout waiting for: '{pattern}'")
if verbose and fpaths_found:
print(f"The following files were polled and found: {fpaths_found}")
print(
"The following files were polled and found: "
f"{'; '.join(sorted(fpaths_found))}"
)
return None


Expand Down
Loading