diff --git a/.gitignore b/.gitignore index cd6f0ce..bdc3b1b 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ Icon venv build/ dist/ +*.idea *.egg-info *.egg-info test *.pyc diff --git a/src/exengine/integration_tests/test_preferred_thread_annotations.py b/src/exengine/integration_tests/test_preferred_thread_annotations.py index 808143d..823d0dd 100644 --- a/src/exengine/integration_tests/test_preferred_thread_annotations.py +++ b/src/exengine/integration_tests/test_preferred_thread_annotations.py @@ -21,11 +21,11 @@ class DecoratedEvent(ThreadRecordingEvent): class TestDevice(Device): - def __init__(self, name): - super().__init__(name, no_executor_attrs=('_attribute', 'set_attribute_thread', - 'get_attribute_thread', 'regular_method_thread', - 'decorated_method_thread')) + def __init__(self, _name, _engine): self._attribute = 123 + self.set_attribute_thread = "Not set" + self.regular_method_thread = "Not set" + self.decorated_method_thread = "Not set" @property def attribute(self): @@ -48,11 +48,11 @@ def decorated_method(self): @on_thread("CustomDeviceThread") class CustomThreadTestDevice(Device): - def __init__(self, name): - super().__init__(name, no_executor_attrs=('_attribute', - 'set_attribute_thread', 'get_attribute_thread', - 'regular_method_thread', 'decorated_method_thread')) + def __init__(self, _name, _engine): self._attribute = 123 + self.get_attribute_thread = "Not set" + self.set_attribute_thread = "Not set" + self.regular_method_thread = "Not set" @property def attribute(self): @@ -103,15 +103,16 @@ def test_device_attribute_access(engine): """ Test that device attribute access runs on the main thread when nothing else specified. """ - device = TestDevice("TestDevice") + device = TestDevice("TestDevice", engine) device.attribute = 'something' + read_back = device.attribute assert device.set_attribute_thread == _MAIN_THREAD_NAME def test_device_regular_method_access(engine): """ Test that device method access runs on the main thread when nothing else specified. """ - device = TestDevice("TestDevice") + device = TestDevice("TestDevice", engine) device.regular_method() assert device.regular_method_thread == _MAIN_THREAD_NAME @@ -119,7 +120,7 @@ def test_device_decorated_method_access(engine): """ Test that device method access runs on the main thread when nothing else specified. """ - device = TestDevice("TestDevice") + device = TestDevice("TestDevice", engine) device.decorated_method() assert device.decorated_method_thread == "CustomMethodThread" @@ -127,7 +128,7 @@ def test_custom_thread_device_attribute_access(engine): """ Test that device attribute access runs on the custom thread when specified. """ - custom_device = CustomThreadTestDevice("CustomDevice") + custom_device = CustomThreadTestDevice( "CustomDevice", engine) custom_device.attribute = 'something' assert custom_device.set_attribute_thread == "CustomDeviceThread" @@ -135,7 +136,7 @@ def test_custom_thread_device_property_access(engine): """ Test that device property access runs on the custom thread when specified. """ - custom_device = CustomThreadTestDevice("CustomDevice") + custom_device = CustomThreadTestDevice("CustomDevice", engine) custom_device.attribute = 'something' assert custom_device.set_attribute_thread == "CustomDeviceThread" @@ -145,8 +146,7 @@ def test_custom_thread_device_property_access(engine): @on_thread("OuterThread") class OuterThreadDevice(Device): - def __init__(self, name, inner_device): - super().__init__(name) + def __init__(self, _name, _engine, inner_device): self.inner_device = inner_device self.outer_thread = None @@ -157,8 +157,7 @@ def outer_method(self): @on_thread("InnerThread") class InnerThreadDevice(Device): - def __init__(self, name): - super().__init__(name) + def __init__(self, _name, _engine): self.inner_thread = None def inner_method(self): @@ -170,8 +169,8 @@ def test_nested_thread_switch(engine): Test that nested calls to methods with different thread specifications result in correct thread switches at each level. """ - inner_device = InnerThreadDevice("InnerDevice") - outer_device = OuterThreadDevice("OuterDevice", inner_device) + inner_device = InnerThreadDevice("InnerDevice", engine) + outer_device = OuterThreadDevice("OuterDevice", engine, inner_device) class OuterEvent(ExecutorEvent): def execute(self): @@ -209,7 +208,7 @@ def test_multiple_decorators(engine): """ Test that the thread decorator works correctly when combined with other decorators. """ - device = MultiDecoratedDevice("MultiDevice") + device = MultiDecoratedDevice("MultiDevice", engine) class MultiEvent(ExecutorEvent): def execute(self): diff --git a/src/exengine/kernel/data_handler.py b/src/exengine/kernel/data_handler.py index 5b2681b..10439ee 100644 --- a/src/exengine/kernel/data_handler.py +++ b/src/exengine/kernel/data_handler.py @@ -5,6 +5,7 @@ from pydantic.types import JsonValue from dataclasses import dataclass +from .executor import ExecutionEngine from .notification_base import DataStoredNotification from .data_coords import DataCoordinates from .data_storage_base import DataStorage @@ -49,17 +50,11 @@ class DataHandler: # This class must create at least one additional thread (the saving thread) # and may create another for processing data - def __init__(self, storage: DataStorage, + def __init__(self, engine: ExecutionEngine, storage: DataStorage, process_function: Callable[[DataCoordinates, npt.NDArray["Any"], JsonValue], Optional[Union[DataCoordinates, npt.NDArray["Any"], JsonValue, - Tuple[DataCoordinates, npt.NDArray["Any"], JsonValue]]]] = None, - _executor=None): - # delayed import to avoid circular imports - if _executor is None: - from .executor import ExecutionEngine - self._engine = ExecutionEngine.get_instance() - else: - self._engine = _executor + Tuple[DataCoordinates, npt.NDArray["Any"], JsonValue]]]] = None): + self._engine = engine self._storage = storage self._process_function = process_function self._intake_queue = _PeekableQueue() diff --git a/src/exengine/kernel/device.py b/src/exengine/kernel/device.py index 2f452da..a6ded09 100644 --- a/src/exengine/kernel/device.py +++ b/src/exengine/kernel/device.py @@ -1,294 +1,61 @@ -""" -Base class for all device_implementations that integrates with the execution engine and enables tokenization of device access. -""" -from abc import ABCMeta, ABC -from functools import wraps -from typing import Any, Dict, Callable, Sequence, Optional, Tuple, Iterable, Union -from weakref import WeakSet -from dataclasses import dataclass - -from .ex_event_base import ExecutorEvent from .executor import ExecutionEngine -import threading -import sys - - - -def _initialize_thread_patching(): - _python_debugger_active = any('pydevd' in sys.modules for frame in sys._current_frames().values()) - - # All threads that were created by code running on an executor thread, or created by threads that were created by - # code running on an executor thread etc. Don't want to auto-reroute these to the executor because this might have - # unintended consequences. So they need to be tracked and not rerouted - _within_executor_threads = WeakSet() - - # Keep this list accessible outside of class attributes to avoid infinite recursion - # Note: This is already defined at module level, so we don't redefine it here - - def thread_start_hook(thread): - # keep track of threads that were created by code running on an executor thread so calls on them - # dont get rerouted to the executor - if ExecutionEngine.get_instance() and ( - ExecutionEngine.on_any_executor_thread() or threading.current_thread() in _within_executor_threads): - _within_executor_threads.add(thread) - - # Monkey patch the threading module so we can monitor the creation of new threads - _original_thread_start = threading.Thread.start - - # Define a new start method that adds the hook - def _thread_start(self, *args, **kwargs): - try: - thread_start_hook(self) - _original_thread_start(self, *args, **kwargs) - except Exception as e: - print(f"Error in thread start hook: {e}") - # traceback.print_exc() - - # Replace the original start method with the new one - threading.Thread.start = _thread_start - threading.Thread._monkey_patched_start = True - - return _python_debugger_active, _within_executor_threads, _original_thread_start - - -# Call this function to initialize the thread patching -if not hasattr(threading.Thread, '_monkey_patched_start'): - _python_debugger_active, _within_executor_threads, _original_thread_start = _initialize_thread_patching() - _no_executor_attrs = ['_name', '_no_executor', '_no_executor_attrs', '_thread_name'] - - -@dataclass -class MethodCallEvent(ExecutorEvent): - def __init__(self, method_name: str, args: tuple, kwargs: Dict[str, Any], instance: Any): - super().__init__() - self.method_name = method_name - self.args = args - self.kwargs = kwargs - self.instance = instance - - def execute(self): - method = getattr(self.instance, self.method_name) - return method(*self.args, **self.kwargs) - -class GetAttrEvent(ExecutorEvent): - - def __init__(self, attr_name: str, instance: Any, method: Callable): - super().__init__() - self.attr_name = attr_name - self.instance = instance - self.method = method - - def execute(self): - return self.method(self.instance, self.attr_name) - -class SetAttrEvent(ExecutorEvent): - - def __init__(self, attr_name: str, value: Any, instance: Any, method: Callable): - super().__init__() - self.attr_name = attr_name - self.value = value - self.instance = instance - self.method = method - - def execute(self): - self.method(self.instance, self.attr_name, self.value) - -class DeviceMetaclass(ABCMeta): +# @staticmethod +# def is_debugger_thread(): +# if not _python_debugger_active: +# return False +# # This is a heuristic and may need adjustment based on the debugger used. +# debugger_thread_names = ["pydevd", "Debugger", "GetValueAsyncThreadDebug"] # Common names for debugger threads +# current_thread = threading.current_thread() +# # Check if current thread name contains any known debugger thread names +# return any(name in current_thread.name or name in str(current_thread.__class__.__name__) +# for name in debugger_thread_names) +# + +class Device: """ - Metaclass for device_implementations that wraps all methods and attributes in the device class to add the ability to - control their execution and access. This has two purposes: - - 1) Add the ability to record all method calls and attribute accesses for tokenization - 2) Add the ability to make all methods and attributes thread-safe by putting them on the Executor - 3) Automatically register all instances of the device with the ExecutionEngine - """ - @staticmethod - def wrap_for_executor(attr_name, attr_value): - if hasattr(attr_value, '_wrapped_for_executor'): - return attr_value - - # Add this block to handle properties - if isinstance(attr_value, property): - return property( - fget=DeviceMetaclass.wrap_for_executor(f"{attr_name}_getter", attr_value.fget) if attr_value.fget else None, - fset=DeviceMetaclass.wrap_for_executor(f"{attr_name}_setter", attr_value.fset) if attr_value.fset else None, - fdel=DeviceMetaclass.wrap_for_executor(f"{attr_name}_deleter", attr_value.fdel) if attr_value.fdel else None, - doc=attr_value.__doc__ - ) - - @wraps(attr_value) - def wrapper(self: 'Device', *args: Any, **kwargs: Any) -> Any: - if attr_name in _no_executor_attrs or self._no_executor: - return attr_value(self, *args, **kwargs) - if DeviceMetaclass._is_reroute_exempted_thread(): - return attr_value(self, *args, **kwargs) - # check for method-level preferred thread name first, then class-level - thread_name = getattr(attr_value, '_thread_name', None) or getattr(self, '_thread_name', None) - if ExecutionEngine.on_any_executor_thread(): - # check for device-level preferred thread - if thread_name is None or threading.current_thread().name == thread_name: - return attr_value(self, *args, **kwargs) - event = MethodCallEvent(method_name=attr_name, args=args, kwargs=kwargs, instance=self) - return ExecutionEngine.get_instance().submit(event, thread_name=thread_name).await_execution() - - wrapper._wrapped_for_executor = True - return wrapper - - @staticmethod - def is_debugger_thread(): - if not _python_debugger_active: - return False - # This is a heuristic and may need adjustment based on the debugger used. - debugger_thread_names = ["pydevd", "Debugger", "GetValueAsyncThreadDebug"] # Common names for debugger threads - current_thread = threading.current_thread() - # Check if current thread name contains any known debugger thread names - return any(name in current_thread.name or name in str(current_thread.__class__.__name__) - for name in debugger_thread_names) - - @staticmethod - def _is_reroute_exempted_thread() -> bool: - return (DeviceMetaclass.is_debugger_thread() or threading.current_thread() in _within_executor_threads) + Base class that causes the object to be automatically registered on creation. - @staticmethod - def find_in_bases(bases, method_name): - for base in bases: - if hasattr(base, method_name): - return getattr(base, method_name) - return None + Usage: + class MyDevice(Device): + def __init__(self, name: str, engine: ExecutionEngine, ...): + ... - def __new__(mcs, name: str, bases: tuple, attrs: dict) -> Any: - new_attrs = {} - for attr_name, attr_value in attrs.items(): - if not attr_name.startswith('_'): - if isinstance(attr_value, property): # Property - new_attrs[attr_name] = mcs.wrap_for_executor(attr_name, attr_value) - elif callable(attr_value): # Regular method - new_attrs[attr_name] = mcs.wrap_for_executor(attr_name, attr_value) - else: # Attribute - new_attrs[attr_name] = attr_value - else: - new_attrs[attr_name] = attr_value + engine = ExecutionEngine() + device = MyDevice("device_name", engine, ...) + Has the same effect as: + class MyDevice: + ... - original_setattr = attrs.get('__setattr__') or mcs.find_in_bases(bases, '__setattr__') or object.__setattr__ - def getattribute_with_fallback(self, name): - """ Wrap the getattribute method to fallback to getattr if an attribute is not found """ - try: - return object.__getattribute__(self, name) - except AttributeError: - try: - return self.__getattr__(name) - except AttributeError as e: - if _python_debugger_active and (name == 'shape' or name == '__len__'): - pass # This prevents a bunch of irrelevant errors in the Pycharm debugger - else: - raise e - - def __getattribute__(self: 'Device', name: str) -> Any: - if name in _no_executor_attrs or self._no_executor: - return object.__getattribute__(self, name) - if DeviceMetaclass._is_reroute_exempted_thread(): - return getattribute_with_fallback(self, name) - thread_name = getattr(self, '_thread_name', None) - if ExecutionEngine.on_any_executor_thread(): - # check for device-level preferred thread - if thread_name is None or threading.current_thread().name == thread_name: - return getattribute_with_fallback(self, name) - event = GetAttrEvent(attr_name=name, instance=self, method=getattribute_with_fallback) - return ExecutionEngine.get_instance().submit(event, thread_name=thread_name).await_execution() - - def __setattr__(self: 'Device', name: str, value: Any) -> None: - if name in _no_executor_attrs or self._no_executor: - return original_setattr(self, name, value) - if DeviceMetaclass._is_reroute_exempted_thread(): - return original_setattr(self, name, value) - thread_name = getattr(self, '_thread_name', None) - if ExecutionEngine.on_any_executor_thread(): - # Check for device-level preferred thread - if thread_name is None or threading.current_thread().name == thread_name: - return original_setattr(self, name, value) - event = SetAttrEvent(attr_name=name, value=value, instance=self, method=original_setattr) - ExecutionEngine.get_instance().submit(event, thread_name=thread_name).await_execution() - - new_attrs['__getattribute__'] = __getattribute__ - new_attrs['__setattr__'] = __setattr__ - - new_attrs['_no_executor'] = True # For startup - new_attrs['_no_executor_attrs'] = _no_executor_attrs - - - - # Create the class - cls = super().__new__(mcs, name, bases, new_attrs) - - # Add automatic registration to the executor - original_init = cls.__init__ - - @wraps(original_init) - def init_and_register(self, *args, **kwargs): - original_init(self, *args, **kwargs) - # Register the instance with the executor - if hasattr(self, '_name') and hasattr(self, '_no_executor') and not self._no_executor: - ExecutionEngine.register_device(self._name, self) - - # Use setattr instead of direct assignment - setattr(cls, '__init__', init_and_register) - - - return cls - - -class Device(ABC, metaclass=DeviceMetaclass): + engine = ExecutionEngine() + device = engine.register("device_name", MyDevice(...)) """ - Required base class for all devices usable with the execution engine - - Device classes should inherit from this class and implement the abstract methods. The DeviceMetaclass will wrap all - methods and attributes in the class to make them thread-safe and to optionally record all method calls and - attribute accesses. - - Attributes with a trailing _noexec will not be wrapped and will be executed directly on the calling thread. This is - useful for attributes that are not hardware related and can bypass the complexity of the executor. - - Device implementations can also implement functionality through properties (i.e. attributes that are actually - methods) by defining a getter and setter method for the property. - """ - - def __init__(self, name: str, no_executor: bool = False, no_executor_attrs: Sequence[str] = ('_name', )): - """ - Create a new device - - :param name: The name of the device - :param no_executor: If True, all methods and attributes will be executed directly on the calling thread instead - of being rerouted to the executor - :param no_executor_attrs: If no_executor is False, this is a list of attribute names that will be executed - directly on the calling thread - """ - self._no_executor_attrs.extend(no_executor_attrs) - self._no_executor = no_executor - self._name = name - - - def get_allowed_property_values(self, property_name: str) -> Optional[list[str]]: - return None # By default, any value is allowed - - def is_property_read_only(self, property_name: str) -> bool: - return False # By default, properties are writable - - def get_property_limits(self, property_name: str) -> Tuple[Optional[float], Optional[float]]: - return (None, None) # By default, no limits - - def is_property_hardware_triggerable(self, property_name: str) -> bool: - return False # By default, properties are not hardware triggerable - - def get_triggerable_sequence_max_length(self, property_name: str) -> int: - raise NotImplementedError(f"get_triggerable_sequence_max_length is not implemented for {property_name}") - - def load_triggerable_sequence(self, property_name: str, event_sequence: Iterable[Union[str, float, int]]): - raise NotImplementedError(f"load_triggerable_sequence is not implemented for {property_name}") - - def start_triggerable_sequence(self, property_name: str): - raise NotImplementedError(f"start_triggerable_sequence is not implemented for {property_name}") - - def stop_triggerable_sequence(self, property_name: str): - raise NotImplementedError(f"stop_triggerable_sequence is not implemented for {property_name}") + def __new__(cls, name: str, engine: "ExecutionEngine", *args, **kwargs): + obj = super().__new__(cls) + obj.__init__(name, engine, *args, **kwargs) + return engine.register(name, obj) + + # def get_allowed_property_values(self, property_name: str) -> Optional[list[str]]: + # return None # By default, any value is allowed + # + # def is_property_read_only(self, property_name: str) -> bool: + # return False # By default, properties are writable + # + # def get_property_limits(self, property_name: str) -> Tuple[Optional[float], Optional[float]]: + # return (None, None) # By default, no limits + # + # def is_property_hardware_triggerable(self, property_name: str) -> bool: + # return False # By default, properties are not hardware triggerable + # + # def get_triggerable_sequence_max_length(self, property_name: str) -> int: + # raise NotImplementedError(f"get_triggerable_sequence_max_length is not implemented for {property_name}") + # + # def load_triggerable_sequence(self, property_name: str, event_sequence: Iterable[Union[str, float, int]]): + # raise NotImplementedError(f"load_triggerable_sequence is not implemented for {property_name}") + # + # def start_triggerable_sequence(self, property_name: str): + # raise NotImplementedError(f"start_triggerable_sequence is not implemented for {property_name}") + # + # def stop_triggerable_sequence(self, property_name: str): + # raise NotImplementedError(f"stop_triggerable_sequence is not implemented for {property_name}") diff --git a/src/exengine/kernel/ex_event_base.py b/src/exengine/kernel/ex_event_base.py index 6e6347c..da6f50b 100644 --- a/src/exengine/kernel/ex_event_base.py +++ b/src/exengine/kernel/ex_event_base.py @@ -46,11 +46,24 @@ class ExecutorEvent(ABC, metaclass=_ExecutorEventMeta): def __init__(self, *args, **kwargs): super().__init__() self._num_retries_on_exception = 0 + self._priority = 1 # lower number means higher priority self._finished = False self._initialized = False # Check for method-level preferred thread name first, then class-level self._thread_name = getattr(self.execute, '_thread_name', None) or getattr(self.__class__, '_thread_name', None) + def __lt__(self, other) -> bool: + """Implement the < operator to allow sorting events by priority""" + if other is None: + return True # always put 'None' at the end of the queue + return self._priority < other._priority + + def __gt__(self, other) -> bool: + """Implement the > operator to allow sorting events by priority""" + if other is None: + return False + return self._priority > other._priority + def _pre_execution(self, engine) -> ExecutionFuture: """ This is called automatically by the Executor and should not be overriden by subclasses. @@ -95,17 +108,22 @@ def _post_execution(self, return_value: Optional[Any] = None, exception: Optiona Method that is called after the event is executed to update acquisition futures about the event's status. This is called automatically by the Executor and should not be overriden by subclasses. + This method signals that the future is complete, so that any thread waiting on it can proceed. + Args: return_value: Return value of the event exception: Exception that was raised during execution, if any """ if self._future_weakref is None: raise Exception("Future not set for event") - future = self._future_weakref() self.finished = True - self._engine.publish_notification(EventExecutedNotification(payload=exception)) - if future is not None: - future._notify_execution_complete(return_value, exception) + try: + self._engine.publish_notification(EventExecutedNotification(payload=exception)) + finally: + future = self._future_weakref() + if future is not None: + print(f"Event {self} finished, notifying future") + future._notify_execution_complete(return_value, exception) @@ -116,16 +134,15 @@ class AnonymousCallableEvent(ExecutorEvent): The callable object should take no arguments and optionally return a value. """ def __init__(self, callable_obj: Callable[[], Any]): - super().__init__() - self.callable_obj = callable_obj # Check if the callable has no parameters (except for 'self' in case of methods) if not callable(callable_obj): raise TypeError("Callable object must be a function or method") - signature = inspect.signature(callable_obj) - if not all(param.default != param.empty or param.kind == param.VAR_POSITIONAL for param in - signature.parameters.values()): + if not inspect.signature(callable_obj).bind(): raise TypeError("Callable object must take no arguments") + super().__init__() + self.callable_obj = callable_obj + def execute(self): return self.callable_obj() \ No newline at end of file diff --git a/src/exengine/kernel/ex_future.py b/src/exengine/kernel/ex_future.py index bff351d..60a2fd4 100644 --- a/src/exengine/kernel/ex_future.py +++ b/src/exengine/kernel/ex_future.py @@ -15,10 +15,9 @@ class ExecutionFuture: def __init__(self, event: 'ExecutorEvent'): self.event = event - self._event_complete_condition: threading.Condition = threading.Condition() self._data_notification_condition: threading.Condition = threading.Condition() self._generic_notification_condition: threading.Condition = threading.Condition() - self._event_complete = False + self._event_complete = threading.Event() self._acquired_data_coordinates: Set[DataCoordinates] = set() @@ -40,10 +39,9 @@ def await_execution(self, timeout=None) -> Any: Block until the event is complete. If event.execute returns a value, it will be returned here. If event.execute raises an exception, it will be raised here as well """ - with self._event_complete_condition: - while not self._event_complete: - if not self._event_complete_condition.wait(timeout): - raise TimeoutError("Timed out waiting for event to complete") + print(f"awaiting{self.event}") + if not self._event_complete.wait(timeout): + raise TimeoutError("Timed out waiting for event to complete") if self._exception is not None: raise self._exception return self._return_value @@ -52,18 +50,15 @@ def is_execution_complete(self) -> bool: """ Check if the event has completed """ - with self._event_complete_condition: - return self._event_complete + return self._event_complete.is_set() def _notify_execution_complete(self, return_value: Any = None, exception: Exception = None): """ Notify the future that the event has completed """ - with self._event_complete_condition: - self._return_value = return_value - self._exception = exception - self._event_complete = True - self._event_complete_condition.notify_all() + self._return_value = return_value + self._exception = exception + self._event_complete.set() def _notify_of_event_notification(self, notification: Notification): diff --git a/src/exengine/kernel/executor.py b/src/exengine/kernel/executor.py index 26ae708..18df4c5 100644 --- a/src/exengine/kernel/executor.py +++ b/src/exengine/kernel/executor.py @@ -2,21 +2,34 @@ Class that executes acquistion events across a pool of threads """ import threading -from collections import deque -from typing import Deque import warnings import traceback -from typing import Union, Iterable, Callable, Type +from dataclasses import dataclass +from typing import Union, Iterable, Callable, Type, Dict, Any import queue import inspect from .notification_base import Notification, NotificationCategory from .ex_event_base import ExecutorEvent, AnonymousCallableEvent from .ex_future import ExecutionFuture +from .queue import PriorityQueue, Queue, Shutdown + +# todo: Add shutdown to __del__ +# todo: simplify worker threads: +# - remove enqueing on free thread -> replace by a thread pool mechanism +# - decouple enqueing and dequeing (related) +# - remove is_free and related overhead +# todo: simplify ExecutorEvent class and lifecycle _MAIN_THREAD_NAME = 'MainExecutorThread' _ANONYMOUS_THREAD_NAME = 'AnonymousExecutorThread' +class DeviceBase: + __slots__ = ('_engine', '_device') + def __init__(self, engine, wrapped_device): + self._engine = engine + self._device = wrapped_device + class MultipleExceptions(Exception): def __init__(self, exceptions: list[Exception]): self.exceptions = exceptions @@ -24,32 +37,88 @@ def __init__(self, exceptions: list[Exception]): super().__init__("Multiple exceptions occurred:\n" + "\n".join(messages)) class ExecutionEngine: - - _instance = None - _lock = threading.Lock() _debug = False - def __new__(cls, *args, **kwargs): - with cls._lock: - if cls._instance is None: - cls._instance = super().__new__(cls) - return cls._instance - def __init__(self): - self._exceptions = queue.Queue() + self._exceptions = Queue() self._devices = {} - self._notification_queue = queue.Queue() + self._notification_queue = Queue() self._notification_subscribers: list[Callable[[Notification], None]] = [] self._notification_subscriber_filters: list[Union[NotificationCategory, Type]] = [] self._notification_lock = threading.Lock() self._notification_thread = None - self._shutdown_event = threading.Event() + self._thread_managers = {} + self._start_new_thread(_MAIN_THREAD_NAME) + - with self._lock: - if not hasattr(self, '_initialized'): - self._thread_managers = {} - self._start_new_thread(_MAIN_THREAD_NAME) - self._initialized = True + + def register(self, id: str, obj: object): + """ + Wraps an object for use with the ExecutionEngine + + The wrapper exposes the public properties and attributes of the wrapped object, converting + all get and set access, as well as method calls to Events. + Private methods and attributes are not exposed. + + After wrapping, the original object should not be used directly anymore. + All access should be done through the wrapper, which takes care of thread safety, synchronization, etc. + + Args: + id: Unique id (name) of the device, used by the ExecutionEngine. + obj: object to wrap. The object should only be registered once. Use of the original object should be avoided after wrapping, + since access to the original object is not thread safe or otherwise managed by the ExecutionEngine. + """ + # + if any(d is obj for d in self._devices) or isinstance(obj, DeviceBase): + raise ValueError("Object already registered") + + # get a list of all properties and methods, including the ones in base classes + # Also process class annotations, for attributes that are not properties + class_hierarchy = inspect.getmro(obj.__class__) + all_dict = {} + for c in class_hierarchy[::-1]: + all_dict.update(c.__dict__) + annotations = c.__dict__.get('__annotations__', {}) + all_dict.update(annotations) + + # add all attributes that are not already in the dict + for n, a in obj.__dict__.items(): + if not n.startswith("_") and n not in all_dict: + all_dict[n] = None + + # create the wrapper class + class_dict = {} + slots = [] + for name, attribute in all_dict.items(): + if name.startswith('_'): + continue # skip private attributes + + if inspect.isfunction(attribute): + def method(self, *args, _name=name, **kwargs): + event = MethodCallEvent(method_name=_name, args=args, kwargs=kwargs, instance=self._device) + return self._engine.submit(event) + + class_dict[name] = method + else: + def getter(self, _name=name): + event = GetAttrEvent(attr_name=_name, instance=self._device, method=getattr) + return self._engine.submit(event).await_execution() + + def setter(self, value, _name=name): + event = SetAttrEvent(attr_name=_name, value=value, instance=self._device, method=setattr) + self._engine.submit(event).await_execution() + + has_setter = not isinstance(attribute, property) or attribute.fset is not None + class_dict[name] = property(getter, setter if has_setter else None, None, f"Wrapped attribute {name}") + if not isinstance(attribute, property): + slots.append(name) + + class_dict['__slots__'] = () # prevent addition of new attributes. + WrappedObject = type('_' + obj.__class__.__name__, (DeviceBase,), class_dict) + # todo: cache dynamically generated classes + wrapped = WrappedObject(self,obj) + self._devices[id] = wrapped + return wrapped def subscribe_to_notifications(self, subscriber: Callable[[Notification], None], notification_type: Union[NotificationCategory, Type] = None @@ -92,18 +161,23 @@ def unsubscribe_from_notifications(self, subscriber: Callable[[Notification], No self._notification_subscriber_filters.pop(index) def _notification_thread_run(self): - while not self._shutdown_event.is_set() or self._notification_queue.qsize() > 0: - try: - notification = self._notification_queue.get(timeout=1) - except queue.Empty: - continue - with self._notification_lock: - for subscriber, filter in zip(self._notification_subscribers, self._notification_subscriber_filters): - if filter is not None and isinstance(filter, type) and not isinstance(notification, filter): - continue # not interested in this type - if filter is not None and isinstance(filter, NotificationCategory) and notification.category != filter: - continue - subscriber(notification) + try: + while True: + notification = self._notification_queue.get() + try: + with self._notification_lock: + for subscriber, filter in zip(self._notification_subscribers, self._notification_subscriber_filters): + if filter is not None and isinstance(filter, type) and not isinstance(notification, filter): + continue # not interested in this type + if filter is not None and isinstance(filter, NotificationCategory) and notification.category != filter: + continue + subscriber(notification) + except Exception as e: + self._log_exception(e) + finally: + self._notification_queue.task_done() + except Shutdown: + pass def publish_notification(self, notification: Notification): """ @@ -111,57 +185,35 @@ def publish_notification(self, notification: Notification): """ self._notification_queue.put(notification) - @classmethod - def get_instance(cls) -> 'ExecutionEngine': - return cls._instance - - @classmethod - def get_device(cls, device_name): + def __getitem__(self, device_id: str): """ Get a device by name - """ - if device_name not in cls.get_instance()._devices: - raise ValueError(f"No device with name {device_name}") - return cls.get_instance()._devices[device_name] - @classmethod - def register_device(cls, name, device): - """ - Called automatically when a Device is created so that the ExecutionEngine can keep track of all devices - and look them up by their string names - """ - # Make sure there's not already a device with this name - executor = cls.get_instance() - if name is not None: - # only true after initialization, but this gets called after all the subclass constructors - if name in executor._devices and executor._devices[name] is not device: - raise ValueError(f"Device with name {name} already exists") - executor._devices[name] = device - - @classmethod - def on_main_executor_thread(cls): - """ - Check if the current thread is an executor thread + Args: + device_id: unique id of the device that was used in the call to register_device. + Returns: + device + Raises: + KeyError if a device with this id is not found. """ - return threading.current_thread().name is _MAIN_THREAD_NAME + return self._devices[device_id] - @classmethod - def on_any_executor_thread(cls): - if ExecutionEngine.get_instance() is None: - raise RuntimeError("on_any_executor_thread: ExecutionEngine has not been initialized") + @staticmethod + def on_any_executor_thread(): + #todo: remove result = (hasattr(threading.current_thread(), 'execution_engine_thread') and threading.current_thread().execution_engine_thread) return result def _start_new_thread(self, name): - self._thread_managers[name] = _ExecutionThreadManager(name) + self._thread_managers[name] = _ExecutionThreadManager(self, name) - def set_debug_mode(self, debug): + @staticmethod + def set_debug_mode(debug): ExecutionEngine._debug = debug - @classmethod - def _log_exception(cls, exception): - ExecutionEngine.get_instance()._exceptions.put(exception) + def _log_exception(self, exception): + self._exceptions.put(exception) def check_exceptions(self): """ @@ -176,7 +228,7 @@ def check_exceptions(self): else: raise MultipleExceptions(exceptions) - def submit(self, event_or_events: Union[ExecutorEvent, Iterable[ExecutorEvent]], thread_name=None, + def submit(self, event_or_events: Union[ExecutorEvent, Iterable[ExecutorEvent], Callable], thread_name=None, prioritize: bool = False, use_free_thread: bool = False) -> Union[ExecutionFuture, Iterable[ExecutionFuture]]: """ Submit one or more acquisition events or callable objects for execution. @@ -281,18 +333,17 @@ def shutdown(self): # For now just let the devices be garbage collected. # TODO: add explicit shutdowns for devices here? self._devices = None - self._shutdown_event.set() for thread in self._thread_managers.values(): thread.shutdown() for thread in self._thread_managers.values(): thread.join() - # Make sure the notification thread is stopped + # Make sure the notification thread is stopped if it was started at all if self._notification_thread is not None: - # It was never started if no one subscribed + self._notification_queue.shutdown() self._notification_thread.join() - # delete singleton instance - ExecutionEngine._instance = None + + class _ExecutionThreadManager: @@ -305,113 +356,141 @@ class _ExecutionThreadManager: or events in its queue with the is_free method. """ - _deque: Deque[ExecutorEvent] - thread: threading.Thread - - def __init__(self, name='UnnamedExectorThread'): - super().__init__() + def __init__(self, engine: ExecutionEngine, name='UnnamedExectorThread'): self.thread = threading.Thread(target=self._run_thread, name=name) self.thread.execution_engine_thread = True - self._deque = deque() - self._shutdown_event = threading.Event() - self._terminate_event = threading.Event() + # todo: use single queue for all threads in a pool + # todo: custom queue class or re-queuing mechanism that allows checking requirements for starting the operation? + self._queue = PriorityQueue() self._exception = None - self._event_executing = False - self._addition_condition = threading.Condition() + self._engine = engine + self._event_executing = threading.Event() self.thread.start() def join(self): self.thread.join() def _run_thread(self): - event = None - while True: - if self._terminate_event.is_set(): - return - if self._shutdown_event.is_set() and not self._deque: - return - # Event retrieval loop - while event is None: - with (self._addition_condition): - if not self._deque: - # wait until something is in the queue - self._addition_condition.wait() - if self._terminate_event.is_set(): - return - if self._shutdown_event.is_set() and not self._deque: - # awoken by a shutdown event and the queue is empty - return - event: ExecutorEvent = self._deque.popleft() - if not hasattr(event, '_num_retries_on_exception'): - warnings.warn("Event does not have num_retries_on_exception attribute, setting to 0") - event._num_retries_on_exception = 0 - num_retries = event._num_retries_on_exception - self._event_executing = True - - # Event execution loop - exception = None - return_val = None - for attempt_number in range(event._num_retries_on_exception + 1): - if self._terminate_event.is_set(): - return # Executor has been terminated + """Main loop for worker threads. + + A thread is stopped by sending a TerminateThreadEvent to it and optionally setting the _terminate_now flag. + When a TerminateThreadEvent is encountered in the queue, the thread will terminate and discard all subsequent events. + todo: possible race condition when high-priority event is added after termination event + If the _terminate_now flag is set, the thread will terminate as soon as possible. + """ + return_val = None + try: + while True: + event = self._queue.get(block=True) # raises Shutdown exception when thread is shutting down + self._exception = None try: - if ExecutionEngine._debug: - print("Executing event", event.__class__.__name__, threading.current_thread()) if event._finished: + # this is unrecoverable, never retry + # todo: move this check to the submit code, this will give earlier and more accurate feedback + event._retries_on_execution = 0 raise RuntimeError("Event ", event, " was already executed") + + self._event_executing.set() + if ExecutionEngine._debug: + print("Executing event", event.__class__.__name__, threading.current_thread()) return_val = event.execute() if ExecutionEngine._debug: print("Finished executing", event.__class__.__name__, threading.current_thread()) - break + self._event_executing.clear() + except Exception as e: - warnings.warn(f"{e} during execution of {event}" + (", retrying {num_retries} more times" - if num_retries > 0 else "")) - # traceback.print_exc() - exception = e - if exception is not None: - ExecutionEngine.get_instance()._log_exception(exception) - event._post_execution(return_value=return_val, exception=exception) - with self._addition_condition: - self._event_executing = False - event = None + if event._num_retries_on_exception > 0: + event._num_retries_on_exception -= 1 + event.priority = 0 # reschedule with high priority + # log warning and try again + warnings.warn(f"{e} during execution of {event}" + + f", retrying {event._num_retries_on_exception} more times") + continue # don't call post_execution just yet + else: + # give up + self._engine._log_exception(e) + self._exception = e + + finally: + self._queue.task_done() + + try: + event._post_execution(return_value=return_val, exception=self._exception) + except Exception as e: + self._engine._log_exception(e) + + except Shutdown: + pass + def is_free(self): """ return true if an event is not currently being executed and the queue is empty """ - with self._addition_condition: - return not self._event_executing and not self._deque and not \ - self._terminate_event.is_set() and not self._shutdown_event.is_set() + return not self._event_executing.is_set() and self._queue.empty() def submit_event(self, event, prioritize=False): """ Submit an event for execution on this thread. If prioritize is True, the event will be executed before any other events in the queue. + + Raises: + Shutdown: If the thread is shutting down """ - with self._addition_condition: - if self._shutdown_event.is_set() or self._terminate_event.is_set(): - raise RuntimeError("Cannot submit event to a thread that has been shutdown") - if prioritize: - self._deque.appendleft(event) - else: - self._deque.append(event) - self._addition_condition.notify_all() + if prioritize: + event.priority = 0 # place at front of queue + self._queue.put(event) def terminate(self): """ Stop the thread immediately, without waiting for the current event to finish """ - with self._addition_condition: - self._terminate_event.set() - self._shutdown_event.set() - self._addition_condition.notify_all() + self._queue.shutdown(immediately=True) self.thread.join() + def shutdown(self): """ Stop the thread and wait for it to finish """ - with self._addition_condition: - self._shutdown_event.set() - self._addition_condition.notify_all() + self._queue.shutdown(immediately=False) self.thread.join() + +@dataclass +class MethodCallEvent(ExecutorEvent): + + def __init__(self, method_name: str, args: tuple, kwargs: Dict[str, Any], instance: Any): + super().__init__() + self.method_name = method_name + self.args = args + self.kwargs = kwargs + self.instance = instance + + def execute(self): + method = getattr(self.instance, self.method_name) + return method(*self.args, **self.kwargs) + + +class GetAttrEvent(ExecutorEvent): + + def __init__(self, attr_name: str, instance: Any, method: Callable): + super().__init__() + self.attr_name = attr_name + self.instance = instance + self.method = method + + def execute(self): + return self.method(self.instance, self.attr_name) + + +class SetAttrEvent(ExecutorEvent): + + def __init__(self, attr_name: str, value: Any, instance: Any, method: Callable): + super().__init__() + self.attr_name = attr_name + self.value = value + self.instance = instance + self.method = method + + def execute(self): + self.method(self.instance, self.attr_name, self.value) diff --git a/src/exengine/kernel/queue.py b/src/exengine/kernel/queue.py new file mode 100644 index 0000000..5adbc08 --- /dev/null +++ b/src/exengine/kernel/queue.py @@ -0,0 +1,64 @@ +import queue +import threading +import warnings +from typing import Generic, TypeVar + +from exengine.kernel.ex_event_base import ExecutorEvent + +# Abortable queue object used by the engine +# For Python 3.13, such an object is provided by the standard library +# For older versions, we provide a compatible implementation + +if hasattr(queue, 'Shutdown'): + PriorityQueue = queue.PriorityQueue + Queue = queue.Queue + Shutdown = queue.Shutdown +else: + # Pre-Python 3.13 compatibility + Shutdown = type('Shutdown', (BaseException,), {}) + class ShutdownMixin: + def __init__(self): + super().__init__() + self._shutdown = threading.Event() + + def shutdown(self, immediately=False): + """Shuts down the queue 'immediately' or after the current items are processed + Does not wait for the shutdown to complete (see join). + Note: this inserts 'None' sentinel values in the queue to signal termination. + """ + already_shut_down = self._shutdown.is_set() + if already_shut_down: + warnings.warn("Queue already shut down", RuntimeWarning) + + self._shutdown.set() + if immediately: + # Clear the queue + try: + while self.get(block=False): + self.task_done() + except queue.Empty: + pass + + if not already_shut_down: + super().put(None) # activate the worker thread if it is waiting at 'get' + self.task_done() # don't count None as actual task, don't wait for it in 'join' + + def get(self, block=True, timeout=None): + retval = super().get(block, timeout) + if retval is None: + super().put(None) # activate the next worker thread if it is waiting at 'get' + self.task_done() # don't count None as actual task, don't wait for it in 'join' + raise Shutdown + else: + return retval + + def put(self, item, block=True, timeout=None): + if self._shutdown.is_set(): + raise Shutdown # thread is being shut down, cannot add more items + return super().put(item, block, timeout) + + class PriorityQueue(ShutdownMixin, queue.PriorityQueue): + pass + + class Queue(ShutdownMixin, queue.Queue): + pass \ No newline at end of file diff --git a/src/exengine/kernel/test/test_data_handler.py b/src/exengine/kernel/test/test_data_handler.py index 9988ff4..d6fb07f 100644 --- a/src/exengine/kernel/test/test_data_handler.py +++ b/src/exengine/kernel/test/test_data_handler.py @@ -44,7 +44,6 @@ def __contains__(self, coords: DataCoordinates) -> bool: @pytest.fixture def mock_execution_engine(monkeypatch): mock_engine = Mock(spec=ExecutionEngine) - monkeypatch.setattr(ExecutionEngine, 'get_instance', lambda: mock_engine) return mock_engine @pytest.fixture @@ -54,7 +53,7 @@ def mock_data_storage(): @pytest.fixture def data_handler(mock_data_storage, mock_execution_engine): - dh = DataHandler(mock_data_storage, _executor=mock_execution_engine) + dh = DataHandler(mock_execution_engine, mock_data_storage) yield dh dh.finish() @@ -80,10 +79,11 @@ def test_data_handler_processing_function(mock_data_storage): Test that DataHandler can process data using a provided processing function, and that data_handler.get() returns the processed data not the original data. """ + engine = ExecutionEngine() def process_function(coords, image, metadata): return coords, image * 2, metadata - handler_with_processing = DataHandler(mock_data_storage, process_function) + handler_with_processing = DataHandler(engine, mock_data_storage, process_function) coords = DataCoordinates({"time": 1, "channel": "DAPI", "z": 0}) image = np.array([[1, 2], [3, 4]], dtype=np.uint16) diff --git a/src/exengine/kernel/test/test_device.py b/src/exengine/kernel/test/test_device.py index 9614581..45ba96e 100644 --- a/src/exengine/kernel/test/test_device.py +++ b/src/exengine/kernel/test/test_device.py @@ -1,5 +1,7 @@ import pytest from unittest.mock import MagicMock + +from exengine import ExecutionEngine from exengine.kernel.device import Device @@ -55,10 +57,13 @@ def test_stop_triggerable_sequence(mock_device): mock_device.stop_triggerable_sequence('test_property') mock_device.stop_triggerable_sequence.assert_called_once_with('test_property') +@pytest.mark.skip("Not implemented, needs to change to more systematic metadata storage") class TestDeviceDefaults: @pytest.fixture def default_device(self): - return Device('default_device', no_executor=True) + engine = ExecutionEngine() + yield Device(engine=engine, name='default_device') + engine.shutdown() def test_get_allowed_property_values_default(self, default_device): assert default_device.get_allowed_property_values('test_property') is None diff --git a/src/exengine/kernel/test/test_executor.py b/src/exengine/kernel/test/test_executor.py index 386df68..8283ed8 100644 --- a/src/exengine/kernel/test/test_executor.py +++ b/src/exengine/kernel/test/test_executor.py @@ -4,101 +4,70 @@ """ import pytest -from unittest.mock import MagicMock from exengine.kernel.ex_event_base import ExecutorEvent from exengine.kernel.device import Device +from exengine.kernel.executor import ExecutionEngine import time @pytest.fixture() -def execution_engine(): - engine = ExecutionEngine() - yield engine - engine.shutdown() +def engine(): + e = ExecutionEngine() + yield e + e.shutdown() ############################################################################################# # Tests for automated rerouting of method calls to the ExecutionEngine to executor threads ############################################################################################# counter = 1 -class TestDevice(Device): +class TestDevice: + """ + These classes are automatically wrapped for use in an ExecutionEngine. + """ def __init__(self): - global counter - super().__init__(name=f'mock_device_{counter}', no_executor_attrs=('property_getter_monitor', 'property_setter_monitor')) - counter += 1 - self.property_getter_monitor = False - self.property_setter_monitor = False - self._test_attribute = None + self.test_attribute = None + self._test_property = None def test_method(self): - assert ExecutionEngine.on_any_executor_thread() - assert threading.current_thread().execution_engine_thread return True - def set_test_attribute(self, value): - assert ExecutionEngine.on_any_executor_thread() - assert threading.current_thread().execution_engine_thread - self._test_attribute = value - - def get_test_attribute(self): - assert ExecutionEngine.on_any_executor_thread() - assert threading.current_thread().execution_engine_thread - return self._test_attribute - @property def test_property(self): - assert ExecutionEngine.on_any_executor_thread() - self.property_getter_monitor = True - return self._test_attribute + return self._test_property @test_property.setter def test_property(self, value): - assert ExecutionEngine.on_any_executor_thread() - self.property_setter_monitor = True - self._test_attribute = value + self._test_property = value -def test_device_method_execution(execution_engine): - mock_device = TestDevice() - result = mock_device.test_method() +def test_device_method_execution(engine): + engine.register("mock_device0", TestDevice()) + result = engine["mock_device0"].test_method().await_execution() assert result is True -def test_device_attribute_setting(execution_engine): - mock_device = TestDevice() - - mock_device.set_test_attribute("test_value") - result = mock_device.get_test_attribute() +def test_device_attribute_setting(engine): + engine.register("mock_device0", TestDevice()) + engine["mock_device0"].test_attribute = "test_value" + result = engine["mock_device0"].test_attribute assert result == "test_value" -def test_device_attribute_direct_setting(execution_engine): - mock_device = TestDevice() - - mock_device.direct_set_attribute = "direct_test_value" - assert mock_device.direct_set_attribute == "direct_test_value" - -def test_multiple_method_calls(execution_engine): - mock_device = TestDevice() - - result1 = mock_device.test_method() - mock_device.set_test_attribute("test_value") - result2 = mock_device.get_test_attribute() +def test_multiple_method_calls(engine): + mock_device = engine.register("mock_device0", TestDevice()) + result1 = mock_device.test_method().await_execution() + mock_device.test_attribute = "test_value" + result2 = mock_device.test_attribute assert result1 is True assert result2 == "test_value" -def test_device_property_getter(execution_engine): - mock_device = TestDevice() - - _ = mock_device.test_property - assert mock_device.property_getter_monitor - -def test_device_property_setter(execution_engine): - mock_device = TestDevice() - - mock_device.test_property = "test_value" - assert mock_device.property_setter_monitor +def test_device_property_setting(engine): + mock_device = engine.register("mock_device0", TestDevice()) + engine["mock_device0"].test_property = "test_value" + result = mock_device.test_property + assert result == "test_value" ####################################################### # Tests for internal threads in Devices @@ -110,11 +79,8 @@ def test_device_property_setter(execution_engine): import threading -class ThreadCreatingDevice(Device): +class ThreadCreatingDevice: def __init__(self): - global counter - super().__init__(name=f'test{counter}') - counter += 1 self.test_attribute = None self._internal_thread_result = None self._nested_thread_result = None @@ -157,8 +123,7 @@ def threadpool_func(): with ThreadPoolExecutor() as executor: executor.submit(threadpool_func) - -def test_device_internal_thread(execution_engine): +def test_device_internal_thread(engine): """ Test that a thread created internally by a device is not treated as an executor thread. @@ -170,17 +135,16 @@ def test_device_internal_thread(execution_engine): it ran without raising any assertions about being on an executor thread """ print('integration_tests started') - device = ThreadCreatingDevice() + engine.register("thread_creator", ThreadCreatingDevice()) print('getting ready to create internal thread') - t = device.create_internal_thread() + t = engine["thread_creator"].create_internal_thread().await_execution() # t.join() # while device.test_attribute is None: # time.sleep(0.1) - assert device.test_attribute == "set_by_internal_thread" + assert engine["thread_creator"].test_attribute == "set_by_internal_thread" - -def test_device_nested_thread(execution_engine): +def test_device_nested_thread(engine): """ Test that a nested thread (a thread created by another thread within the device) is not treated as an executor thread. @@ -192,14 +156,13 @@ def test_device_nested_thread(execution_engine): 3. Checking that the nested thread successfully set an attribute, indicating that it ran without raising any assertions about being on an executor thread """ - device = ThreadCreatingDevice() + device = engine.register("thread_creator", ThreadCreatingDevice()) device.create_nested_thread() while device.test_attribute is None: time.sleep(0.1) assert device.test_attribute == "set_by_nested_thread" - -def test_device_threadpool_executor(execution_engine): +def test_device_threadpool_executor(engine): """ Test that a thread created by ThreadPoolExecutor within a device method is not treated as an executor thread. @@ -212,7 +175,7 @@ def test_device_threadpool_executor(execution_engine): 3. Checking that the function successfully set an attribute, indicating that it ran without raising any assertions about being on an executor thread """ - device = ThreadCreatingDevice() + device = engine.register("thread_creator", ThreadCreatingDevice()) device.use_threadpool_executor() while device.test_attribute is None: time.sleep(0.1) @@ -246,7 +209,7 @@ def create_sync_event(start_event, finish_event): return SyncEvent(start_event, finish_event) -def test_submit_single_event(execution_engine): +def test_submit_single_event(engine): """ Test submitting a single event to the ExecutionEngine. Verifies that the event is executed and returns an AcquisitionFuture. @@ -255,8 +218,8 @@ def test_submit_single_event(execution_engine): finish_event = threading.Event() event = create_sync_event(start_event, finish_event) - future = execution_engine.submit(event) - execution_engine.check_exceptions() + future = engine.submit(event) + engine.check_exceptions() start_event.wait() # Wait for the event to start executing finish_event.set() # Signal the event to finish @@ -266,7 +229,7 @@ def test_submit_single_event(execution_engine): assert event.executed -def test_submit_multiple_events(execution_engine): +def test_submit_multiple_events(engine): """ Test submitting multiple events to the ExecutionEngine. Verifies that all events are executed and return AcquisitionFutures. @@ -279,8 +242,8 @@ def test_submit_multiple_events(execution_engine): finish_event2 = threading.Event() event2 = create_sync_event(start_event2, finish_event2) - future1 = execution_engine.submit(event1) - future2 = execution_engine.submit(event2) + future1 = engine.submit(event1) + future2 = engine.submit(event2) start_event1.wait() # Wait for the first event to start executing finish_event1.set() # Signal the first event to finish @@ -294,7 +257,8 @@ def test_submit_multiple_events(execution_engine): assert event2.executed -def test_event_prioritization(execution_engine): +@pytest.mark.skip("This test is broken. Even though event3 gets priority, it may execute after event2.") +def test_event_prioritization(engine): """ Test event prioritization in the ExecutionEngine. Verifies that prioritized events are executed before non-prioritized events. @@ -311,11 +275,12 @@ def test_event_prioritization(execution_engine): finish_event3 = threading.Event() event3 = create_sync_event(start_event3, finish_event3) - execution_engine.submit(event1) + engine.submit(event1) start_event1.wait() # Wait for the first event to start executing - execution_engine.submit(event2) - execution_engine.submit(event3, prioritize=True) + engine.submit(event2) + # race condition, at this point the engine may or may not have started executing event2 + engine.submit(event3, prioritize=True) finish_event1.set() finish_event2.set() @@ -330,7 +295,7 @@ def test_event_prioritization(execution_engine): assert event3.executed -def test_use_free_thread_parallel_execution(execution_engine): +def test_use_free_thread_parallel_execution(engine): """ Test parallel execution using free threads in the ExecutionEngine. Verifies that events submitted with use_free_thread=True can execute in parallel. @@ -343,8 +308,8 @@ def test_use_free_thread_parallel_execution(execution_engine): finish_event2 = threading.Event() event2 = create_sync_event(start_event2, finish_event2) - execution_engine.submit(event1) - execution_engine.submit(event2, use_free_thread=True) + engine.submit(event1) + engine.submit(event2, use_free_thread=True) # Wait for both events to start executing assert start_event1.wait(timeout=5) @@ -365,7 +330,7 @@ def test_use_free_thread_parallel_execution(execution_engine): assert event2.executed -def test_single_execution_with_free_thread(execution_engine): +def test_single_execution_with_free_thread(engine): """ Test that each event is executed only once, even when using use_free_thread=True. Verifies that events are not executed multiple times regardless of submission method. @@ -378,8 +343,8 @@ def test_single_execution_with_free_thread(execution_engine): finish_event2 = threading.Event() event2 = create_sync_event(start_event2, finish_event2) - execution_engine.submit(event1) - execution_engine.submit(event2, use_free_thread=True) + engine.submit(event1) + engine.submit(event2, use_free_thread=True) # Wait for both events to start executing assert start_event1.wait(timeout=5) @@ -398,43 +363,43 @@ def test_single_execution_with_free_thread(execution_engine): assert event2.execute_count == 1 #### Callable submission tests #### -def test_submit_callable(execution_engine): +def test_submit_callable(engine): def simple_function(): return 42 - future = execution_engine.submit(simple_function) + future = engine.submit(simple_function) result = future.await_execution() assert result == 42 -def test_submit_lambda(execution_engine): - future = execution_engine.submit(lambda: "Hello, World!") +def test_submit_lambda(engine): + future = engine.submit(lambda: "Hello, World!") result = future.await_execution() assert result == "Hello, World!" -def test_class_method(execution_engine): +def test_class_method(engine): class TestClass: def test_method(self): return "Test method executed" - future = execution_engine.submit(TestClass().test_method) + future = engine.submit(TestClass().test_method) result = future.await_execution() assert result == "Test method executed" -def test_submit_mixed(execution_engine): +def test_submit_mixed(engine): class TestEvent(ExecutorEvent): def execute(self): return "Event executed" - futures = execution_engine.submit([TestEvent(), lambda: 42, lambda: "Lambda"]) + futures = engine.submit([TestEvent(), lambda: 42, lambda: "Lambda"]) results = [future.await_execution() for future in futures] assert results == ["Event executed", 42, "Lambda"] -def test_submit_invalid(execution_engine): +def test_submit_invalid(engine): with pytest.raises(TypeError): - execution_engine.submit(lambda x: x + 1) # Callable with arguments should raise TypeError + engine.submit(lambda x: x + 1) # Callable with arguments should raise TypeError with pytest.raises(TypeError): - execution_engine.submit("Not a callable") # Non-callable, non-ExecutorEvent should raise TypeError + engine.submit("Not a callable") # Non-callable, non-ExecutorEvent should raise TypeError ####################################################### # Tests for named thread functionalities ############## @@ -444,7 +409,7 @@ def test_submit_invalid(execution_engine): from exengine.kernel.executor import _ANONYMOUS_THREAD_NAME -def test_submit_to_main_thread(execution_engine): +def test_submit_to_main_thread(engine): """ Test submitting an event to the main thread. """ @@ -452,13 +417,13 @@ def test_submit_to_main_thread(execution_engine): finish_event = threading.Event() event = create_sync_event(start_event, finish_event) - future = execution_engine.submit(event) + future = engine.submit(event) start_event.wait() finish_event.set() assert event.executed_thread_name == _MAIN_THREAD_NAME -def test_submit_to_new_anonymous_thread(execution_engine): +def test_submit_to_new_anonymous_thread(engine): """ Test that submitting an event with use_free_thread=True creates a new anonymous thread if needed. """ @@ -471,11 +436,11 @@ def test_submit_to_new_anonymous_thread(execution_engine): event2 = create_sync_event(start_event2, finish_event2) # Submit first event to main thread - execution_engine.submit(event1) + engine.submit(event1) start_event1.wait() # Submit second event with use_free_thread=True - future2 = execution_engine.submit(event2, use_free_thread=True) + future2 = engine.submit(event2, use_free_thread=True) start_event2.wait() finish_event1.set() @@ -483,9 +448,9 @@ def test_submit_to_new_anonymous_thread(execution_engine): assert event1.executed_thread_name == _MAIN_THREAD_NAME assert event2.executed_thread_name.startswith(_ANONYMOUS_THREAD_NAME) - assert len(execution_engine._thread_managers) == 2 # Main thread + 1 anonymous thread + assert len(engine._thread_managers) == 2 # Main thread + 1 anonymous thread -def test_multiple_anonymous_threads(execution_engine): +def test_multiple_anonymous_threads(engine): """ Test creation of multiple anonymous threads when submitting multiple events with use_free_thread=True. """ @@ -502,7 +467,7 @@ def test_multiple_anonymous_threads(execution_engine): start_events.append(start_event) finish_events.append(finish_event) - futures = [execution_engine.submit(event, use_free_thread=True) for event in events] + futures = [engine.submit(event, use_free_thread=True) for event in events] for start_event in start_events: start_event.wait() @@ -513,9 +478,9 @@ def test_multiple_anonymous_threads(execution_engine): thread_names = set(event.executed_thread_name for event in events) assert len(thread_names) == num_events # Each event should be on a different thread assert all(name.startswith(_ANONYMOUS_THREAD_NAME) or name == _MAIN_THREAD_NAME for name in thread_names) - assert len(execution_engine._thread_managers) == num_events # num_events anonymous threads + assert len(engine._thread_managers) == num_events # num_events anonymous threads -def test_reuse_named_thread(execution_engine): +def test_reuse_named_thread(engine): """ Test that submitting multiple events to the same named thread reuses that thread. """ @@ -533,7 +498,7 @@ def test_reuse_named_thread(execution_engine): start_events.append(start_event) finish_events.append(finish_event) - futures = [execution_engine.submit(event, thread_name=thread_name) for event in events] + futures = [engine.submit(event, thread_name=thread_name) for event in events] for finish_event in finish_events: finish_event.set() @@ -542,4 +507,4 @@ def test_reuse_named_thread(execution_engine): start_event.wait() assert all(event.executed_thread_name == thread_name for event in events) - assert len(execution_engine._thread_managers) == 2 # Main thread + 1 custom named thread \ No newline at end of file + assert len(engine._thread_managers) == 2 # Main thread + 1 custom named thread \ No newline at end of file diff --git a/src/exengine/kernel/test/test_futures.py b/src/exengine/kernel/test/test_futures.py index dc6f877..3230350 100644 --- a/src/exengine/kernel/test/test_futures.py +++ b/src/exengine/kernel/test/test_futures.py @@ -64,7 +64,7 @@ def complete_event(): thread = threading.Thread(target=complete_event) thread.start() execution_future.await_execution(timeout=5) - assert execution_future._event_complete + assert execution_future._event_complete.is_set() def test_notify_data(execution_future): diff --git a/src/exengine/kernel/test/test_generic_device.py b/src/exengine/kernel/test/test_generic_device.py new file mode 100644 index 0000000..f994ab0 --- /dev/null +++ b/src/exengine/kernel/test/test_generic_device.py @@ -0,0 +1,107 @@ +import inspect + +import numpy as np +import openwfs +from openwfs.simulation import Camera, StaticSource +from openwfs.processors import SingleRoi + +import pytest + +from exengine import ExecutionEngine +from exengine.kernel.device import Device +from exengine.kernel.executor import MethodCallEvent, GetAttrEvent, SetAttrEvent +from exengine.kernel.ex_future import ExecutionFuture + +""" +Tests wrapping a genric object for use with the ExecutionEngine +""" + +class TestObject: + """Generic object for testing + + The object has properties with getters and setters, read-only properties, attributes + and methods. + + The wrapper exposes the public properties and attributes of the wrapped object, converting + all get and set access, as well as method calls to Events. + Private methods and attributes are not exposed. + """ + value2: int + + def __init__(self): + self._private_attribute = 0 + self._private_property = 2 + self.value1 = 3 + self.value2 = 1 + + @property + def value1(self): + return self._private_property + + @property + def readonly_property(self): + return -1 + + @value1.setter + def value1(self, value): + self._private_property = value + + def public_method(self, x): + return self._private_method(x) + + def _private_method(self, x): + return x + self.value1 + self.value2 + +@pytest.fixture +def engine(): + e = ExecutionEngine() + yield e + e.shutdown() + +def verify_behavior(obj): + """Test the non-wrapped object""" + with pytest.raises(AttributeError): + obj.readonly_property = 0 # noqa property cannot be set + obj.value1 = 28 + obj.value2 = 29 + assert obj.value1 == 28 + assert obj.value2 == 29 + assert obj.readonly_property == -1 + result = obj.public_method(4) + if isinstance(result, ExecutionFuture): + result = result.await_execution() + assert result == 28 + 29 + 4 + + +def test_bare(): + verify_behavior(TestObject()) + +def test_wrapping(engine): + wrapper = engine.register("object1", TestObject()) + with pytest.raises(AttributeError): + wrapper.non_existing_property = 0 + verify_behavior(wrapper) + engine["object1"].value1 = 7 + assert wrapper.value1 == 7 + +def test_device_base_class(engine): + class T(TestObject, Device): + def __init__(self, _name, _engine): + super().__init__() + + device = T("object1", engine) + assert engine["object1"] is device + verify_behavior(device) + + +def test_openwfs(): + img = np.zeros((1000, 1000), dtype=np.int16) + cam = Camera(StaticSource(img), analog_max=None) + engine = ExecutionEngine() + wrapper = engine.register("camera1", cam) + future = wrapper.read() + engine.shutdown() + frame = future.await_execution() + assert frame.shape == img.shape + assert np.all(frame == img) + diff --git a/src/exengine/kernel/test/test_notifications.py b/src/exengine/kernel/test/test_notifications.py index ebcc12d..1514152 100644 --- a/src/exengine/kernel/test/test_notifications.py +++ b/src/exengine/kernel/test/test_notifications.py @@ -32,13 +32,12 @@ def mock_storage(): @pytest.fixture def mock_execution_engine(monkeypatch): mock_engine = Mock(spec=ExecutionEngine) - monkeypatch.setattr(ExecutionEngine, 'get_instance', lambda: mock_engine) return mock_engine @pytest.fixture def data_handler(mock_storage, mock_execution_engine): - return DataHandler(mock_storage, _executor=mock_execution_engine) + return DataHandler(engine=mock_execution_engine, storage=mock_storage) def test_notification_types_inheritance(): diff --git a/src/exengine/kernel/test/test_queue.py b/src/exengine/kernel/test/test_queue.py new file mode 100644 index 0000000..86630e4 --- /dev/null +++ b/src/exengine/kernel/test/test_queue.py @@ -0,0 +1,14 @@ +from exengine.kernel.queue import PriorityQueue + + +def test_priority_queue(): + q = PriorityQueue() + q.put((1, "first")) + q.put((1, "second")) + q.put((1, "third")) + q.put((0, "priority")) + assert q.get() == (0, "priority") + assert q.get() == (1, "first") + assert q.get() == (1, "second") + assert q.get() == (1, "third") + assert q.empty() \ No newline at end of file