Skip to content

Commit

Permalink
completed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cpelley committed Sep 26, 2024
1 parent 287852b commit 0038a4a
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 168 deletions.
12 changes: 10 additions & 2 deletions dagrunner/execute_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,22 @@ def plugin_executor(
callable_kwargs = {}
callable_kwargs_init = {}
else:
raise ValueError(f'expecting 1, 2 or 3 values to unpack for {callable_obj}, got {len(call)}')
raise ValueError(
f"expecting 1, 2 or 3 values to unpack for {callable_obj}, got {len(call)}"
)
callable_kwargs_init = (
{} if callable_kwargs_init is None else callable_kwargs_init
)
else:
if len(call) == 2:
_, callable_kwargs = call
elif len(call) == 1:
callable_kwargs = {}
else:
raise ValueError(f'expecting 1 or 2 values to unpack for {callable_obj}, got {len(call)}')
raise ValueError(
f"expecting 1 or 2 values to unpack for {callable_obj}, got {len(call)}"
)
callable_kwargs = {} if callable_kwargs is None else callable_kwargs

call_msg = ""
obj_name = callable_obj.__name__
Expand Down
79 changes: 33 additions & 46 deletions dagrunner/tests/execute_graph/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import os
import time
from dataclasses import dataclass
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -83,7 +82,7 @@ def graph(tmp_path_factory):
for nodenum in range(1, 6):
node = vars()[f"node{nodenum}"]
SETTINGS[node] = {
"call": tuple([ProcessID, {"id": nodenum}]),
"call": tuple([ProcessID, None, {"id": nodenum}]),
}

node_save = Node(step="save", leadtime=leadtime)
Expand All @@ -92,7 +91,7 @@ def graph(tmp_path_factory):
# we let SaveJson expand the filepath for us from the node properties (leadtime)
SETTINGS[node_save] = {
"call": tuple(
[SaveJson, {"filepath": f"{tmp_dir}/result_{{leadtime}}.json"}]
[SaveJson, None, {"filepath": f"{tmp_dir}/result_{{leadtime}}.json"}]
)
}
return EDGES, SETTINGS, output_files
Expand All @@ -115,16 +114,13 @@ def test_execution(graph, scheduler):
# parallel execution performance.
debug = False
EDGES, SETTINGS, output_files = graph
with patch("dagrunner.execute_graph.logger.ServerContext"), patch(
"dagrunner.execute_graph.logger.client_attach_socket_handler"
):
graph = ExecuteGraph(
(EDGES, SETTINGS),
num_workers=3,
scheduler=scheduler,
verbose=False,
debug=debug,
)()
graph = ExecuteGraph(
(EDGES, SETTINGS),
num_workers=3,
scheduler=scheduler,
verbose=False,
debug=debug,
)()
for output_file in output_files:
with open(output_file, "r") as file:
# two of them are expected since we have two leadtime branches
Expand All @@ -145,17 +141,14 @@ def test_skip_execution(graph):

# skip execution of the second branch
SETTINGS[Node(step="step2", leadtime=HOUR)] = {
"call": tuple([SkipExe, {"id": 2}]),
"call": tuple([SkipExe, None, {"id": 2}]),
}
with patch("dagrunner.execute_graph.logger.ServerContext"), patch(
"dagrunner.execute_graph.logger.client_attach_socket_handler"
):
graph = ExecuteGraph(
(EDGES, SETTINGS),
num_workers=3,
scheduler=scheduler,
verbose=False,
)()
graph = ExecuteGraph(
(EDGES, SETTINGS),
num_workers=3,
scheduler=scheduler,
verbose=False,
)()
output_file = output_files[0]
with open(output_file, "r") as file:
# two of them are expected since we have two leadtime branches
Expand All @@ -178,19 +171,16 @@ def test_multiprocessing_error_handling(graph):

# # skip execution of the second branch
SETTINGS[Node(step="step2", leadtime=HOUR)] = {
"call": tuple([RaiseErr, {"id": 2}]),
"call": tuple([RaiseErr, None, {"id": 2}]),
}
with patch("dagrunner.execute_graph.logger.ServerContext"), patch(
"dagrunner.execute_graph.logger.client_attach_socket_handler"
):
graph = ExecuteGraph(
(EDGES, SETTINGS),
num_workers=3,
scheduler=scheduler,
verbose=False,
)
with pytest.raises(RuntimeError, match="RaiseErr"):
graph()
graph = ExecuteGraph(
(EDGES, SETTINGS),
num_workers=3,
scheduler=scheduler,
verbose=False,
)
with pytest.raises(RuntimeError, match="RaiseErr"):
graph()


@pytest.fixture()
Expand All @@ -204,10 +194,10 @@ def graph_input():
EDGES.append([node1, node2])

SETTINGS[node1] = {
"call": tuple([Input, {"filepath": "{step}_{leadtime}"}]),
"call": tuple([Input, None, {"filepath": "{step}_{leadtime}"}]),
}
SETTINGS[node2] = {
"call": tuple([lambda x: x, {}]),
"call": tuple([lambda x: x]),
}
return EDGES, SETTINGS

Expand All @@ -219,14 +209,11 @@ def test_override_node_property_with_setting(graph_input, capsys):
new_step = "altered_step"
SETTINGS[Node(step="step1", leadtime=1)] |= {"step": new_step}

with patch("dagrunner.execute_graph.logger.ServerContext"), patch(
"dagrunner.execute_graph.logger.client_attach_socket_handler"
):
_ = ExecuteGraph(
(EDGES, SETTINGS),
num_workers=1,
scheduler=scheduler,
verbose=True,
)()
_ = ExecuteGraph(
(EDGES, SETTINGS),
num_workers=1,
scheduler=scheduler,
verbose=True,
)()
output = capsys.readouterr()
assert f"result: {new_step}_1" in output.out
179 changes: 116 additions & 63 deletions dagrunner/tests/execute_graph/test_plugin_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,98 +4,151 @@
# See LICENSE in the root of the repository for full licensing details.
from unittest import mock

import pytest

from dagrunner.execute_graph import plugin_executor


class DummyPlugin:
def __init__(self, iarg1, ikwarg1=None) -> None:
self._iarg1 = iarg1
self._ikwarg1 = ikwarg1
def __init__(self, init_named_arg, init_named_kwarg=None, **init_kwargs) -> None:
self._init_named_arg = init_named_arg
self._init_named_kwarg = init_named_kwarg
self._init_kwargs = init_kwargs

def __call__(self, *args, kwarg1=None, **kwargs):
def __call__(
self, *call_args, call_named_arg, call_named_kwarg=None, **call_kwargs
):
return (
f"iarg1={self._iarg1}; ikwarg1={self._ikwarg1}; args={args}; "
f"kwarg1={kwarg1}; kwargs={kwargs}"
f"init_kwargs={self._init_kwargs}; "
f"init_named_arg={self._init_named_arg}; init_named_kwarg={self._init_named_kwarg}; "
f"call_args={call_args}; call_kwargs={call_kwargs}; "
f"call_named_arg={call_named_arg}; call_named_kwarg={call_named_kwarg}; "
)


@mock.patch("dagrunner.execute_graph.logger.client_attach_socket_handler")
def test_pass_class_arg_kwargs(mock_client_attach_socket_handler):
"""Test passing named parameters to plugin class and __call__ method."""
args = (mock.sentinel.arg1, mock.sentinel.arg2)
call = tuple(
[
class DummyPluginNoNamedParam:
def __init__(self, **init_kwargs) -> None:
self._init_kwargs = init_kwargs

def __call__(self, *call_args, **call_kwargs):
return (
f"init_kwargs={self._init_kwargs}; "
f"call_args={call_args}; call_kwargs={call_kwargs}; "
)


@pytest.mark.parametrize(
"plugin, init_args, call_args, target",
[
# Passing class init and call args
(
DummyPlugin,
{"iarg1": mock.sentinel.iarg1, "ikwarg1": mock.sentinel.ikwarg1},
{"kwarg1": mock.sentinel.kwarg1},
]
)
{
"init_named_arg": mock.sentinel.init_named_arg,
"init_named_kwarg": mock.sentinel.init_named_kwarg,
"init_other_kwarg": mock.sentinel.init_other_kwarg,
},
{
"call_named_arg": mock.sentinel.call_named_arg,
"call_named_kwarg": mock.sentinel.call_named_kwarg,
"call_other_kwarg": mock.sentinel.call_other_kwarg,
},
(
"init_kwargs={'init_other_kwarg': sentinel.init_other_kwarg}; "
"init_named_arg=sentinel.init_named_arg; "
"init_named_kwarg=sentinel.init_named_kwarg; "
"call_args=(sentinel.arg1, sentinel.arg2); "
"call_kwargs={'call_other_kwarg': sentinel.call_other_kwarg}; "
"call_named_arg=sentinel.call_named_arg; call_named_kwarg=sentinel.call_named_kwarg; "
),
),
# Passing class init args only
(
DummyPluginNoNamedParam,
{"init_other_kwarg": mock.sentinel.init_other_kwarg},
None,
(
"init_kwargs={'init_other_kwarg': sentinel.init_other_kwarg}; "
"call_args=(sentinel.arg1, sentinel.arg2); "
"call_kwargs={}; "
),
),
# Passing class call args only
(
DummyPluginNoNamedParam,
None,
{"call_other_kwarg": mock.sentinel.call_other_kwarg},
(
"init_kwargs={}; "
"call_args=(sentinel.arg1, sentinel.arg2); "
"call_kwargs={'call_other_kwarg': sentinel.call_other_kwarg}; "
),
),
],
)
def test_pass_class_arg_kwargs(plugin, init_args, call_args, target):
"""Test passing named parameters to plugin class initialisation and __call__ method."""
args = (mock.sentinel.arg1, mock.sentinel.arg2)
call = tuple([plugin, init_args, call_args])
res = plugin_executor(*args, call=call)
assert res == (
"iarg1=sentinel.iarg1; ikwarg1=sentinel.ikwarg1; "
"args=(sentinel.arg1, sentinel.arg2); kwarg1=sentinel.kwarg1; "
"kwargs={}"
)
assert res == target


@mock.patch("dagrunner.execute_graph.logger.client_attach_socket_handler")
def test_pass_common_args(mock_client_attach_socket_handler):
def test_pass_common_args():
"""Passing 'common args', some relevant to class init and some to call method."""
args = (mock.sentinel.arg1, mock.sentinel.arg2)
common_kwargs = {
"ikwarg1": mock.sentinel.ikwarg1,
"kwarg1": mock.sentinel.kwarg1,
"iarg1": mock.sentinel.iarg1,
"init_named_arg": mock.sentinel.init_named_arg,
"init_named_kwarg": mock.sentinel.init_named_kwarg,
"other_kwargs": mock.sentinel.other_kwargs, # this should be ignored (as not part of class signature)
"call_named_arg": mock.sentinel.call_named_arg,
"call_named_kwarg": mock.sentinel.call_named_kwarg,
}

# call without common args (iarg1 is positional so non-optional)
call = tuple([DummyPlugin, {"iarg1": mock.sentinel.iarg1}, {}])
res = plugin_executor(*args, call=call)
assert res == (
"iarg1=sentinel.iarg1; ikwarg1=None; args=(sentinel.arg1, "
"sentinel.arg2); kwarg1=None; kwargs={}"
target = (
"init_kwargs={}; "
"init_named_arg=sentinel.init_named_arg; "
"init_named_kwarg=sentinel.init_named_kwarg; "
"call_args=(sentinel.arg1, sentinel.arg2); "
"call_kwargs={}; "
"call_named_arg=sentinel.call_named_arg; "
"call_named_kwarg=sentinel.call_named_kwarg; "
)

# call with common args
call = tuple([DummyPlugin, {}, {}])
res = plugin_executor(*args, call=call, common_kwargs=common_kwargs)
assert res == (
"iarg1=sentinel.iarg1; ikwarg1=sentinel.ikwarg1; "
"args=(sentinel.arg1, sentinel.arg2); kwarg1=sentinel.kwarg1; "
"kwargs={}"
)


class DummyPlugin2:
"""Plugin that is reliant on data not explicitly defined in its UI."""

def __call__(self, *args, **kwargs):
return f"args={args}; kwargs={kwargs}"
assert res == target


@mock.patch("dagrunner.execute_graph.logger.client_attach_socket_handler")
def test_pass_common_args_via_override(mock_client_attach_socket_handler):
"""
Passing 'common args' to a plugin that doesn't have such arguments
defined in its signature. Instead, filter out those that aren't
specified in the graph.
"""
def test_pass_common_args_override():
"""Passing 'common args', some relevant to class init and some to call method."""
args = (mock.sentinel.arg1, mock.sentinel.arg2)
common_kwargs = {
"kwarg1": mock.sentinel.kwarg1,
"kwarg2": mock.sentinel.kwarg2,
"kwarg3": mock.sentinel.kwarg3,
"init_named_arg": mock.sentinel.init_named_arg_override,
"init_named_kwarg": mock.sentinel.init_named_kwarg_override,
"call_named_arg": mock.sentinel.call_named_arg_override,
"call_named_kwarg": mock.sentinel.call_named_kwarg_override,
}
args = []
call = tuple(
[
DummyPlugin2,
DummyPlugin,
{
"kwarg1": mock.sentinel.kwarg1,
"kwarg2": mock.sentinel.kwarg2,
"init_named_arg": mock.sentinel.init_named_arg,
"init_named_kwarg": mock.sentinel.init_named_kwarg,
},
{
"call_named_arg": mock.sentinel.call_named_arg,
"call_named_kwarg": mock.sentinel.call_named_kwarg,
},
]
)
res = plugin_executor(*args, call=call, common_kwargs=common_kwargs)
assert (
res == "args=(); kwargs={'kwarg1': sentinel.kwarg1, 'kwarg2': sentinel.kwarg2}"
target = (
"init_kwargs={}; "
"init_named_arg=sentinel.init_named_arg_override; "
"init_named_kwarg=sentinel.init_named_kwarg_override; "
"call_args=(sentinel.arg1, sentinel.arg2); "
"call_kwargs={}; "
"call_named_arg=sentinel.call_named_arg_override; "
"call_named_kwarg=sentinel.call_named_kwarg_override; "
)
res = plugin_executor(*args, call=call, common_kwargs=common_kwargs)
assert res == target
Loading

0 comments on commit 0038a4a

Please sign in to comment.