Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored Device wrapping #37

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ Icon
venv
build/
dist/
*.idea
*.egg-info
*.egg-info test
*.pyc
Expand Down
39 changes: 19 additions & 20 deletions src/exengine/integration_tests/test_preferred_thread_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ class DecoratedEvent(ThreadRecordingEvent):

class TestDevice(Device):

def __init__(self, name):
super().__init__(name, no_executor_attrs=('_attribute', 'set_attribute_thread',
'get_attribute_thread', 'regular_method_thread',
'decorated_method_thread'))
def __init__(self, _name, _engine):
self._attribute = 123
self.set_attribute_thread = "Not set"
self.regular_method_thread = "Not set"
self.decorated_method_thread = "Not set"

@property
def attribute(self):
Expand All @@ -48,11 +48,11 @@ def decorated_method(self):
@on_thread("CustomDeviceThread")
class CustomThreadTestDevice(Device):

def __init__(self, name):
super().__init__(name, no_executor_attrs=('_attribute',
'set_attribute_thread', 'get_attribute_thread',
'regular_method_thread', 'decorated_method_thread'))
def __init__(self, _name, _engine):
self._attribute = 123
self.get_attribute_thread = "Not set"
self.set_attribute_thread = "Not set"
self.regular_method_thread = "Not set"

@property
def attribute(self):
Expand Down Expand Up @@ -103,39 +103,40 @@ def test_device_attribute_access(engine):
"""
Test that device attribute access runs on the main thread when nothing else specified.
"""
device = TestDevice("TestDevice")
device = TestDevice("TestDevice", engine)
device.attribute = 'something'
read_back = device.attribute
assert device.set_attribute_thread == _MAIN_THREAD_NAME

def test_device_regular_method_access(engine):
"""
Test that device method access runs on the main thread when nothing else specified.
"""
device = TestDevice("TestDevice")
device = TestDevice("TestDevice", engine)
device.regular_method()
assert device.regular_method_thread == _MAIN_THREAD_NAME

def test_device_decorated_method_access(engine):
"""
Test that device method access runs on the main thread when nothing else specified.
"""
device = TestDevice("TestDevice")
device = TestDevice("TestDevice", engine)
device.decorated_method()
assert device.decorated_method_thread == "CustomMethodThread"

def test_custom_thread_device_attribute_access(engine):
"""
Test that device attribute access runs on the custom thread when specified.
"""
custom_device = CustomThreadTestDevice("CustomDevice")
custom_device = CustomThreadTestDevice( "CustomDevice", engine)
custom_device.attribute = 'something'
assert custom_device.set_attribute_thread == "CustomDeviceThread"

def test_custom_thread_device_property_access(engine):
"""
Test that device property access runs on the custom thread when specified.
"""
custom_device = CustomThreadTestDevice("CustomDevice")
custom_device = CustomThreadTestDevice("CustomDevice", engine)
custom_device.attribute = 'something'
assert custom_device.set_attribute_thread == "CustomDeviceThread"

Expand All @@ -145,8 +146,7 @@ def test_custom_thread_device_property_access(engine):

@on_thread("OuterThread")
class OuterThreadDevice(Device):
def __init__(self, name, inner_device):
super().__init__(name)
def __init__(self, _name, _engine, inner_device):
self.inner_device = inner_device
self.outer_thread = None

Expand All @@ -157,8 +157,7 @@ def outer_method(self):

@on_thread("InnerThread")
class InnerThreadDevice(Device):
def __init__(self, name):
super().__init__(name)
def __init__(self, _name, _engine):
self.inner_thread = None

def inner_method(self):
Expand All @@ -170,8 +169,8 @@ def test_nested_thread_switch(engine):
Test that nested calls to methods with different thread specifications
result in correct thread switches at each level.
"""
inner_device = InnerThreadDevice("InnerDevice")
outer_device = OuterThreadDevice("OuterDevice", inner_device)
inner_device = InnerThreadDevice("InnerDevice", engine)
outer_device = OuterThreadDevice("OuterDevice", engine, inner_device)

class OuterEvent(ExecutorEvent):
def execute(self):
Expand Down Expand Up @@ -209,7 +208,7 @@ def test_multiple_decorators(engine):
"""
Test that the thread decorator works correctly when combined with other decorators.
"""
device = MultiDecoratedDevice("MultiDevice")
device = MultiDecoratedDevice("MultiDevice", engine)

class MultiEvent(ExecutorEvent):
def execute(self):
Expand Down
13 changes: 4 additions & 9 deletions src/exengine/kernel/data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic.types import JsonValue
from dataclasses import dataclass

from .executor import ExecutionEngine
from .notification_base import DataStoredNotification
from .data_coords import DataCoordinates
from .data_storage_base import DataStorage
Expand Down Expand Up @@ -49,17 +50,11 @@ class DataHandler:
# This class must create at least one additional thread (the saving thread)
# and may create another for processing data

def __init__(self, storage: DataStorage,
def __init__(self, engine: ExecutionEngine, storage: DataStorage,
process_function: Callable[[DataCoordinates, npt.NDArray["Any"], JsonValue],
Optional[Union[DataCoordinates, npt.NDArray["Any"], JsonValue,
Tuple[DataCoordinates, npt.NDArray["Any"], JsonValue]]]] = None,
_executor=None):
# delayed import to avoid circular imports
if _executor is None:
from .executor import ExecutionEngine
self._engine = ExecutionEngine.get_instance()
else:
self._engine = _executor
Tuple[DataCoordinates, npt.NDArray["Any"], JsonValue]]]] = None):
self._engine = engine
self._storage = storage
self._process_function = process_function
self._intake_queue = _PeekableQueue()
Expand Down
Loading
Loading