From 0038a4a3f0d3f10e7480c7377cbb9d680f2e3bdb Mon Sep 17 00:00:00 2001
From: cpelley
Date: Thu, 26 Sep 2024 16:45:19 +0100
Subject: [PATCH] completed tests

---
 dagrunner/execute_graph.py                   |  12 +-
 .../tests/execute_graph/test_integration.py  |  79 ++++----
 .../execute_graph/test_plugin_executor.py    | 179 ++++++++++++------
 .../tests/utils/logging/test_integration.py  |  58 ++++--
 dagrunner/utils/logger.py                    |  88 +++++----
 5 files changed, 248 insertions(+), 168 deletions(-)

diff --git a/dagrunner/execute_graph.py b/dagrunner/execute_graph.py
index 9a06349..3d89abc 100755
--- a/dagrunner/execute_graph.py
+++ b/dagrunner/execute_graph.py
@@ -160,14 +160,22 @@ def plugin_executor(
             callable_kwargs = {}
             callable_kwargs_init = {}
         else:
-            raise ValueError(f'expecting 1, 2 or 3 values to unpack for {callable_obj}, got {len(call)}')
+            raise ValueError(
+                f"expecting 1, 2 or 3 values to unpack for {callable_obj}, got {len(call)}"
+            )
+        callable_kwargs_init = (
+            {} if callable_kwargs_init is None else callable_kwargs_init
+        )
     else:
         if len(call) == 2:
             _, callable_kwargs = call
         elif len(call) == 1:
             callable_kwargs = {}
         else:
-            raise ValueError(f'expecting 1 or 2 values to unpack for {callable_obj}, got {len(call)}')
+            raise ValueError(
+                f"expecting 1 or 2 values to unpack for {callable_obj}, got {len(call)}"
+            )
+        callable_kwargs = {} if callable_kwargs is None else callable_kwargs
 
     call_msg = ""
     obj_name = callable_obj.__name__
diff --git a/dagrunner/tests/execute_graph/test_integration.py b/dagrunner/tests/execute_graph/test_integration.py
index cc52493..72dfa5b 100644
--- a/dagrunner/tests/execute_graph/test_integration.py
+++ b/dagrunner/tests/execute_graph/test_integration.py
@@ -6,7 +6,6 @@
 import os
 import time
 from dataclasses import dataclass
-from unittest.mock import patch
 
 import pytest
 
@@ -83,7 +82,7 @@ def graph(tmp_path_factory):
     for nodenum in range(1, 6):
         node = vars()[f"node{nodenum}"]
         SETTINGS[node] = {
-            "call": tuple([ProcessID, {"id": nodenum}]),
+            "call": tuple([ProcessID, None, {"id": nodenum}]),
         }
 
     node_save = Node(step="save", leadtime=leadtime)
@@ -92,7 +91,7 @@ def graph(tmp_path_factory):
         # we let SaveJson expand the filepath for us from the node properties (leadtime)
         SETTINGS[node_save] = {
             "call": tuple(
-                [SaveJson, {"filepath": f"{tmp_dir}/result_{{leadtime}}.json"}]
+                [SaveJson, None, {"filepath": f"{tmp_dir}/result_{{leadtime}}.json"}]
             )
         }
     return EDGES, SETTINGS, output_files
@@ -115,16 +114,13 @@ def test_execution(graph, scheduler):
     # parallel execution performance.
     debug = False
     EDGES, SETTINGS, output_files = graph
-    with patch("dagrunner.execute_graph.logger.ServerContext"), patch(
-        "dagrunner.execute_graph.logger.client_attach_socket_handler"
-    ):
-        graph = ExecuteGraph(
-            (EDGES, SETTINGS),
-            num_workers=3,
-            scheduler=scheduler,
-            verbose=False,
-            debug=debug,
-        )()
+    graph = ExecuteGraph(
+        (EDGES, SETTINGS),
+        num_workers=3,
+        scheduler=scheduler,
+        verbose=False,
+        debug=debug,
+    )()
     for output_file in output_files:
         with open(output_file, "r") as file:
             # two of them are expected since we have two leadtime branches
@@ -145,17 +141,14 @@ def test_skip_execution(graph):
 
     # skip execution of the second branch
     SETTINGS[Node(step="step2", leadtime=HOUR)] = {
-        "call": tuple([SkipExe, {"id": 2}]),
+        "call": tuple([SkipExe, None, {"id": 2}]),
     }
-    with patch("dagrunner.execute_graph.logger.ServerContext"), patch(
-        "dagrunner.execute_graph.logger.client_attach_socket_handler"
-    ):
-        graph = ExecuteGraph(
-            (EDGES, SETTINGS),
-            num_workers=3,
-            scheduler=scheduler,
-            verbose=False,
-        )()
+    graph = ExecuteGraph(
+        (EDGES, SETTINGS),
+        num_workers=3,
+        scheduler=scheduler,
+        verbose=False,
+    )()
     output_file = output_files[0]
     with open(output_file, "r") as file:
         # two of them are expected since we have two leadtime branches
@@ -178,19 +171,16 @@ def test_multiprocessing_error_handling(graph):
     #
     # skip execution of the second branch
     SETTINGS[Node(step="step2", leadtime=HOUR)] = {
-        "call": tuple([RaiseErr, {"id": 2}]),
+        "call": tuple([RaiseErr, None, {"id": 2}]),
     }
-    with patch("dagrunner.execute_graph.logger.ServerContext"), patch(
-        "dagrunner.execute_graph.logger.client_attach_socket_handler"
-    ):
-        graph = ExecuteGraph(
-            (EDGES, SETTINGS),
-            num_workers=3,
-            scheduler=scheduler,
-            verbose=False,
-        )
-        with pytest.raises(RuntimeError, match="RaiseErr"):
-            graph()
+    graph = ExecuteGraph(
+        (EDGES, SETTINGS),
+        num_workers=3,
+        scheduler=scheduler,
+        verbose=False,
+    )
+    with pytest.raises(RuntimeError, match="RaiseErr"):
+        graph()
 
 
 @pytest.fixture()
@@ -204,10 +194,10 @@ def graph_input():
     EDGES.append([node1, node2])
 
     SETTINGS[node1] = {
-        "call": tuple([Input, {"filepath": "{step}_{leadtime}"}]),
+        "call": tuple([Input, None, {"filepath": "{step}_{leadtime}"}]),
     }
     SETTINGS[node2] = {
-        "call": tuple([lambda x: x, {}]),
+        "call": tuple([lambda x: x]),
     }
 
     return EDGES, SETTINGS
@@ -219,14 +209,11 @@ def test_override_node_property_with_setting(graph_input, capsys):
 
     new_step = "altered_step"
     SETTINGS[Node(step="step1", leadtime=1)] |= {"step": new_step}
-    with patch("dagrunner.execute_graph.logger.ServerContext"), patch(
-        "dagrunner.execute_graph.logger.client_attach_socket_handler"
-    ):
-        _ = ExecuteGraph(
-            (EDGES, SETTINGS),
-            num_workers=1,
-            scheduler=scheduler,
-            verbose=True,
-        )()
+    _ = ExecuteGraph(
+        (EDGES, SETTINGS),
+        num_workers=1,
+        scheduler=scheduler,
+        verbose=True,
+    )()
     output = capsys.readouterr()
     assert f"result: {new_step}_1" in output.out
diff --git a/dagrunner/tests/execute_graph/test_plugin_executor.py b/dagrunner/tests/execute_graph/test_plugin_executor.py
index 3ec70a5..81c89b1 100644
--- a/dagrunner/tests/execute_graph/test_plugin_executor.py
+++ b/dagrunner/tests/execute_graph/test_plugin_executor.py
@@ -4,98 +4,151 @@
 # See LICENSE in the root of the repository for full licensing details.
 from unittest import mock
 
+import pytest
+
 from dagrunner.execute_graph import plugin_executor
 
 
 class DummyPlugin:
-    def __init__(self, iarg1, ikwarg1=None) -> None:
-        self._iarg1 = iarg1
-        self._ikwarg1 = ikwarg1
+    def __init__(self, init_named_arg, init_named_kwarg=None, **init_kwargs) -> None:
+        self._init_named_arg = init_named_arg
+        self._init_named_kwarg = init_named_kwarg
+        self._init_kwargs = init_kwargs
 
-    def __call__(self, *args, kwarg1=None, **kwargs):
+    def __call__(
+        self, *call_args, call_named_arg, call_named_kwarg=None, **call_kwargs
+    ):
         return (
-            f"iarg1={self._iarg1}; ikwarg1={self._ikwarg1}; args={args}; "
-            f"kwarg1={kwarg1}; kwargs={kwargs}"
+            f"init_kwargs={self._init_kwargs}; "
+            f"init_named_arg={self._init_named_arg}; init_named_kwarg={self._init_named_kwarg}; "
+            f"call_args={call_args}; call_kwargs={call_kwargs}; "
+            f"call_named_arg={call_named_arg}; call_named_kwarg={call_named_kwarg}; "
         )
 
 
-@mock.patch("dagrunner.execute_graph.logger.client_attach_socket_handler")
-def test_pass_class_arg_kwargs(mock_client_attach_socket_handler):
-    """Test passing named parameters to plugin class and __call__ method."""
-    args = (mock.sentinel.arg1, mock.sentinel.arg2)
-    call = tuple(
-        [
+class DummyPluginNoNamedParam:
+    def __init__(self, **init_kwargs) -> None:
+        self._init_kwargs = init_kwargs
+
+    def __call__(self, *call_args, **call_kwargs):
+        return (
+            f"init_kwargs={self._init_kwargs}; "
+            f"call_args={call_args}; call_kwargs={call_kwargs}; "
+        )
+
+
+@pytest.mark.parametrize(
+    "plugin, init_args, call_args, target",
+    [
+        # Passing class init and call args
+        (
             DummyPlugin,
-            {"iarg1": mock.sentinel.iarg1, "ikwarg1": mock.sentinel.ikwarg1},
-            {"kwarg1": mock.sentinel.kwarg1},
-        ]
-    )
+            {
+                "init_named_arg": mock.sentinel.init_named_arg,
+                "init_named_kwarg": mock.sentinel.init_named_kwarg,
+                "init_other_kwarg": mock.sentinel.init_other_kwarg,
+            },
+            {
+                "call_named_arg": mock.sentinel.call_named_arg,
+                "call_named_kwarg": mock.sentinel.call_named_kwarg,
+                "call_other_kwarg": mock.sentinel.call_other_kwarg,
+            },
+            (
+                "init_kwargs={'init_other_kwarg': sentinel.init_other_kwarg}; "
+                "init_named_arg=sentinel.init_named_arg; "
+                "init_named_kwarg=sentinel.init_named_kwarg; "
+                "call_args=(sentinel.arg1, sentinel.arg2); "
+                "call_kwargs={'call_other_kwarg': sentinel.call_other_kwarg}; "
+                "call_named_arg=sentinel.call_named_arg; call_named_kwarg=sentinel.call_named_kwarg; "
+            ),
+        ),
+        # Passing class init args only
+        (
+            DummyPluginNoNamedParam,
+            {"init_other_kwarg": mock.sentinel.init_other_kwarg},
+            None,
+            (
+                "init_kwargs={'init_other_kwarg': sentinel.init_other_kwarg}; "
+                "call_args=(sentinel.arg1, sentinel.arg2); "
+                "call_kwargs={}; "
+            ),
+        ),
+        # Passing class call args only
+        (
+            DummyPluginNoNamedParam,
+            None,
+            {"call_other_kwarg": mock.sentinel.call_other_kwarg},
+            (
+                "init_kwargs={}; "
+                "call_args=(sentinel.arg1, sentinel.arg2); "
+                "call_kwargs={'call_other_kwarg': sentinel.call_other_kwarg}; "
+            ),
+        ),
+    ],
+)
+def test_pass_class_arg_kwargs(plugin, init_args, call_args, target):
+    """Test passing named parameters to plugin class initialisation and __call__ method."""
+    args = (mock.sentinel.arg1, mock.sentinel.arg2)
+    call = tuple([plugin, init_args, call_args])
     res = plugin_executor(*args, call=call)
-    assert res == (
-        "iarg1=sentinel.iarg1; ikwarg1=sentinel.ikwarg1; "
-        "args=(sentinel.arg1, sentinel.arg2); kwarg1=sentinel.kwarg1; "
-        "kwargs={}"
-    )
+    assert res == target
-@mock.patch("dagrunner.execute_graph.logger.client_attach_socket_handler")
-def test_pass_common_args(mock_client_attach_socket_handler):
+def test_pass_common_args():
     """Passing 'common args', some relevant to class init and some to call method."""
     args = (mock.sentinel.arg1, mock.sentinel.arg2)
     common_kwargs = {
-        "ikwarg1": mock.sentinel.ikwarg1,
-        "kwarg1": mock.sentinel.kwarg1,
-        "iarg1": mock.sentinel.iarg1,
+        "init_named_arg": mock.sentinel.init_named_arg,
+        "init_named_kwarg": mock.sentinel.init_named_kwarg,
+        "other_kwargs": mock.sentinel.other_kwargs,  # this should be ignored (as not part of class signature)
+        "call_named_arg": mock.sentinel.call_named_arg,
+        "call_named_kwarg": mock.sentinel.call_named_kwarg,
     }
-
-    # call without common args (iarg1 is positional so non-optional)
-    call = tuple([DummyPlugin, {"iarg1": mock.sentinel.iarg1}, {}])
-    res = plugin_executor(*args, call=call)
-    assert res == (
-        "iarg1=sentinel.iarg1; ikwarg1=None; args=(sentinel.arg1, "
-        "sentinel.arg2); kwarg1=None; kwargs={}"
+    target = (
+        "init_kwargs={}; "
+        "init_named_arg=sentinel.init_named_arg; "
+        "init_named_kwarg=sentinel.init_named_kwarg; "
+        "call_args=(sentinel.arg1, sentinel.arg2); "
+        "call_kwargs={}; "
+        "call_named_arg=sentinel.call_named_arg; "
+        "call_named_kwarg=sentinel.call_named_kwarg; "
     )
 
-    # call with common args
     call = tuple([DummyPlugin, {}, {}])
     res = plugin_executor(*args, call=call, common_kwargs=common_kwargs)
-    assert res == (
-        "iarg1=sentinel.iarg1; ikwarg1=sentinel.ikwarg1; "
-        "args=(sentinel.arg1, sentinel.arg2); kwarg1=sentinel.kwarg1; "
-        "kwargs={}"
-    )
-
-
-class DummyPlugin2:
-    """Plugin that is reliant on data not explicitly defined in its UI."""
-
-    def __call__(self, *args, **kwargs):
-        return f"args={args}; kwargs={kwargs}"
+    assert res == target
 
 
-@mock.patch("dagrunner.execute_graph.logger.client_attach_socket_handler")
-def test_pass_common_args_via_override(mock_client_attach_socket_handler):
-    """
-    Passing 'common args' to a plugin that doesn't have such arguments
-    defined in its signature. Instead, filter out those that aren't
-    specified in the graph.
-    """
+def test_pass_common_args_override():
+    """Passing 'common args' that override the init and call kwargs specified in the graph."""
+    args = (mock.sentinel.arg1, mock.sentinel.arg2)
     common_kwargs = {
-        "kwarg1": mock.sentinel.kwarg1,
-        "kwarg2": mock.sentinel.kwarg2,
-        "kwarg3": mock.sentinel.kwarg3,
+        "init_named_arg": mock.sentinel.init_named_arg_override,
+        "init_named_kwarg": mock.sentinel.init_named_kwarg_override,
+        "call_named_arg": mock.sentinel.call_named_arg_override,
+        "call_named_kwarg": mock.sentinel.call_named_kwarg_override,
     }
-    args = []
     call = tuple(
         [
-            DummyPlugin2,
+            DummyPlugin,
             {
-                "kwarg1": mock.sentinel.kwarg1,
-                "kwarg2": mock.sentinel.kwarg2,
+                "init_named_arg": mock.sentinel.init_named_arg,
+                "init_named_kwarg": mock.sentinel.init_named_kwarg,
+            },
+            {
+                "call_named_arg": mock.sentinel.call_named_arg,
+                "call_named_kwarg": mock.sentinel.call_named_kwarg,
             },
         ]
     )
-    res = plugin_executor(*args, call=call, common_kwargs=common_kwargs)
-    assert (
-        res == "args=(); kwargs={'kwarg1': sentinel.kwarg1, 'kwarg2': sentinel.kwarg2}"
+    target = (
+        "init_kwargs={}; "
+        "init_named_arg=sentinel.init_named_arg_override; "
+        "init_named_kwarg=sentinel.init_named_kwarg_override; "
+        "call_args=(sentinel.arg1, sentinel.arg2); "
+        "call_kwargs={}; "
+        "call_named_arg=sentinel.call_named_arg_override; "
+        "call_named_kwarg=sentinel.call_named_kwarg_override; "
     )
+    res = plugin_executor(*args, call=call, common_kwargs=common_kwargs)
+    assert res == target
diff --git a/dagrunner/tests/utils/logging/test_integration.py b/dagrunner/tests/utils/logging/test_integration.py
index 470e8a3..48ebdc0 100644
--- a/dagrunner/tests/utils/logging/test_integration.py
+++ b/dagrunner/tests/utils/logging/test_integration.py
@@ -2,14 +2,18 @@
 #
 # This file is part of 'dagrunner' and is released under the BSD 3-Clause license.
 # See LICENSE in the root of the repository for full licensing details.
+import inspect
 import logging
 import os
 import sqlite3
 import subprocess
+import sys
+import time
 
 import pytest
 
-from dagrunner.utils.logger import ServerContext
+import dagrunner
+from dagrunner.utils.logger import client_attach_socket_handler
 
 
 @pytest.fixture
@@ -17,20 +21,36 @@ def sqlite_filepath(tmp_path):
     return tmp_path / "test_logs.sqlite"
 
 
-def gen_client_code(loggers):
-    code = "import logging;"
-    code += "from dagrunner.utils.logger import client_attach_socket_handler;"
-    code += "client_attach_socket_handler();"
+@pytest.fixture
+def server(sqlite_filepath):
+    pythonpath = os.path.dirname(os.path.dirname(inspect.getfile(dagrunner)))
+    env = os.environ.copy()
+    env["PYTHONPATH"] = f"{os.path.dirname(__file__)}/../../../utils/:{pythonpath}"
+
+    # Start the server process
+    server_proc = subprocess.Popen(
+        [
+            sys.executable,
+            "-m",
+            "logger",
+            "--sqlite-filepath",
+            str(sqlite_filepath),
+            "--verbose",
+        ],
+        env=env,
+    )
+
+    # Wait for the server to start (adjust as needed)
+    time.sleep(1)  # Can use a more robust method to check server readiness
 
-    for msg, lvlname, name in loggers:
-        if lvlname:
-            code += f"logging.getLogger('{lvlname}').{name}('{msg}');"
-        else:
-            code += f"logging.{name}('{msg}');"
-    return code
+    yield server_proc, sqlite_filepath
 
+    # Teardown: Kill the server after tests are done
+    server_proc.terminate()
+    server_proc.wait()
 
-def test_sqlitedb(sqlite_filepath, caplog):
+
+def test_sqlitedb(server, caplog):
     test_inputs = (
         ["Python is versatile and powerful.", "root", "info"],
         ["Lists store collections of items.", "myapp.area1", "debug"],
@@ -38,11 +58,9 @@ def test_sqlitedb(sqlite_filepath, caplog):
         ["Indentation defines code blocks.", "myapp.area2", "warning"],
         ["Libraries extend Pythons capabilities.", "myapp.area2", "error"],
     )
-    client_code = gen_client_code(test_inputs)
-    with ServerContext(sqlite_filepath=sqlite_filepath, verbose=True):
-        subprocess.run(
-            ["python", "-c", client_code], capture_output=True, text=True, check=True
-        )
+    client_attach_socket_handler()
+    for msg, lvlname, name in test_inputs:
+        getattr(logging.getLogger(lvlname), name)(msg)
 
     # Check log messages
     assert len(caplog.record_tuples) == len(test_inputs)
@@ -54,8 +72,12 @@ def test_sqlitedb(sqlite_filepath, caplog):
             == record
         )
 
+    server_proc, sqlite_filepath = server
+    time.sleep(1)  # wait for db write to complete
+    server_proc.terminate()
+
     # Check there are any records in the database
-    conn = sqlite3.connect(sqlite_filepath)
+    conn = sqlite3.connect(str(sqlite_filepath))
     cursor = conn.cursor()
     cursor.execute("SELECT COUNT(*) FROM logs")
     count = cursor.fetchone()[0]
diff --git a/dagrunner/utils/logger.py b/dagrunner/utils/logger.py
index d914a05..d2ea607 100644
--- a/dagrunner/utils/logger.py
+++ b/dagrunner/utils/logger.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env python3
 # (C) Crown Copyright, Met Office. All rights reserved.
 #
 # This file is part of 'dagrunner' and is released under the BSD 3-Clause license.
@@ -25,10 +24,12 @@
 import logging
 import logging.handlers
+import os
 import pickle
 import queue
 import socket
 import socketserver
+import sqlite3
 import struct
 
 from dagrunner.utils import function_to_argparse_parse_args
@@ -95,12 +96,12 @@ def handle(self):
             record = logging.makeLogRecord(obj)
             # Modify record to include hostname
             record.hostname = socket.gethostname()
-            self.handle_log_record(record)
 
-            # Push log record to the queue for database writing
             if self.server.log_queue is not None:
                 self.server.log_queue.put(record)
 
+            self.handle_log_record(record)
+
     def unpickle(self, data):
         return pickle.loads(data)
 
@@ -135,16 +136,14 @@ def __init__(
         port=logging.handlers.DEFAULT_TCP_LOGGING_PORT,
         handler=LogRecordStreamHandler,
         log_queue=None,
-        queue_handler=None,
     ):
         socketserver.ThreadingTCPServer.__init__(self, (host, port), handler)
         self.abort = 0
         self.timeout = 1
         self.logname = None
         self.log_queue = log_queue  # Store the reference to the log queue
-        self.queue_handler = queue_handler
 
-    def serve_until_stopped(self):
+    def serve_until_stopped(self, queue_handler=None):
         import select
 
         abort = 0
@@ -152,50 +151,51 @@ def serve_until_stopped(self):
             rd, wr, ex = select.select([self.socket.fileno()], [], [], self.timeout)
             if rd:
                 self.handle_request()
-            if self.queue_handler:
-                self.queue_handler.write(self.log_queue)
+            if queue_handler:
+                queue_handler.write(self.log_queue)
             abort = self.abort
-        if self.queue_handler:
-            self.queue_handler.write(self.log_queue)  # Ensure all records are written
-            self.queue_handler.close()
+        if queue_handler:
+            queue_handler.write(self.log_queue)  # Ensure all records are written
+            queue_handler.close()
 
 
 class SQLiteQueueHandler:
     def __init__(self, sqfile="logs.sqlite", verbose=False):
         self._sqfile = sqfile
-        self._conn = None
+        self._conn = sqlite3.connect(self._sqfile)  # Connect to the SQLite database
         self._verbose = verbose
+        self._debug = False
+        sqlite3.enable_callback_tracebacks(self._debug)
+        self.write_table()
 
-    @property
-    def db(self):
-        if self._conn is None:
-            import sqlite3
-
-            if self._verbose:
-                print(f"Writing sqlite file: {self._sqfile}")
-            self._conn = sqlite3.connect(self._sqfile)  # Connect to the SQLite database
-            cursor = self._conn.cursor()
-            cursor.execute("""
-                CREATE TABLE IF NOT EXISTS logs (
-                    created TEXT,
-                    name TEXT,
-                    level TEXT,
-                    message TEXT,
-                    hostname TEXT,
-                    process TEXT,
-                    thread TEXT
-                )
-            """)  # Create the 'logs' table if it doesn't exist
-        return self._conn
+    def write_table(self):
+        if self._verbose:
+            print(f"Writing sqlite file table: {self._sqfile}")
+        # cursors are not thread-safe
+        cursor = self._conn.cursor()
+        cursor.execute("""
+            CREATE TABLE IF NOT EXISTS logs (
+                created TEXT,
+                name TEXT,
+                level TEXT,
+                message TEXT,
+                hostname TEXT,
+                process TEXT,
+                thread TEXT
+            )
+        """)  # Create the 'logs' table if it doesn't exist
+        self._conn.commit()  # Commit the transaction after all writes
+        cursor.close()
 
     def write(self, log_queue):
         if self._verbose:
-            print("Writing to sqlite file")
+            print(f"Writing row to sqlite file: {self._sqfile}")
+        # cursors are not thread-safe
+        cursor = self._conn.cursor()
         while not log_queue.empty():
             record = log_queue.get()
             if self._verbose:
                 print("Dequeued item:", record)
-            cursor = self.db.cursor()
             cursor.execute(
                 "\n"
                 "INSERT INTO logs "
@@ -211,7 +211,8 @@ def write(self, log_queue):
                     record.thread,
                 ),
             )
-            self.db.commit()  # Commit the transaction
+        self._conn.commit()  # Commit the transaction after all writes
+        cursor.close()
 
     def close(self):
         if self._conn:
@@ -265,10 +266,19 @@ def start_logging_server(
         host=host,
         port=port,
         log_queue=log_queue,
-        queue_handler=sqlitequeue,
     )
-    print("About to start TCP server...")
-    tcpserver.serve_until_stopped()
+    print(
+        "About to start TCP server...\n",
+        "HOST:",
+        host,
+        "PORT:",
+        port,
+        "PID:",
+        os.getpid(),
+        "SQLITE:",
+        sqlite_filepath,
+    )
+    tcpserver.serve_until_stopped(queue_handler=sqlitequeue)
 
 
 def main():
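
Illustrative usage (a minimal sketch, not part of the commit above): after this
patch, plugin_executor accepts a 3-tuple "call" of (plugin, init_kwargs,
call_kwargs) and normalises a None kwargs entry to {}, as exercised by the
updated tests. The Echo plugin below is made up for illustration; everything
else follows the signatures shown in the diff:

    from dagrunner.execute_graph import plugin_executor

    class Echo:
        def __init__(self, prefix=""):
            # init kwargs come from the second entry of the "call" tuple
            self._prefix = prefix

        def __call__(self, *args, suffix=""):
            # call kwargs come from the third entry of the "call" tuple
            return f"{self._prefix}{args}{suffix}"

    # None in place of the init kwargs is tolerated after this patch,
    # mirroring e.g. tuple([ProcessID, None, {"id": nodenum}]) in the tests.
    result = plugin_executor("payload", call=(Echo, None, {"suffix": "!"}))
    print(result)  # -> ('payload',)!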