Skip to content

Commit

Permalink
Work on mac (#53)
Browse files Browse the repository at this point in the history
Improve support for the "spawn" start method in multiprocessing. This makes daisy more accessible to MacOS and Windows where "fork" is less convenient to use due to security concerns. To do this we use dill since dill can pickle a larger variety of objects than the default pickle library.

Commits:
* simplify tests

use pytest functions instead of UnitTest classes
Remove unnecessary lambdas (`lambda b: func(b)` is equivalent to `func`)

* remove lambdas and double underscore methods

these cannot be easily pickled so they can't be used with start_method="spawn"

* improve task `process_func` signature parsing

We don't care about kwargs with defaults, we only care about mandatory args.
1 mandatory arg means daisy knows to provide a thin worker wrapper and simply pass the blocks to your
process function. 0 mandatory args means daisy can safely assume the function handles worker creation.

* use dill to store the spawn function as bytes

when creating a worker, we pass in a spawn function. Dill is an improved version of pickle so dill can serialize many
more functions than pickle can. Thus we use dill to serialize the spawn function to bytes, which can then be serialized/deserialized by pickle when
spawning subprocesses. Then when the worker needs to run the spawn function, we can use dill to deserialize the bytes into the desired function.
  • Loading branch information
pattonw authored Jun 13, 2024
1 parent a745dd5 commit ad974c1
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 75 deletions.
17 changes: 12 additions & 5 deletions daisy/task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .client import Client
from inspect import signature
from inspect import getfullargspec


class Task:
Expand Down Expand Up @@ -163,15 +163,22 @@ def __init__(
if init_callback_fn is not None:
self.init_callback_fn = init_callback_fn
else:
self.init_callback_fn = lambda context: None
self.init_callback_fn = self._default_init

if len(signature(process_function).parameters) == 0:
args = getfullargspec(process_function).args
if len(args) == 0:
# spawn function
self.spawn_worker_function = process_function
elif len(args) == 1:
# process block function
self.spawn_worker_function = self._process_blocks
else:
self.spawn_worker_function = lambda: self._process_blocks()
raise ValueError(f"daisy does not know what to pass into args: {args}")

def _process_blocks(self):
def _default_init(self, context):
pass

def _process_blocks(self):
client = Client()
while True:
with client.acquire_block() as block:
Expand Down
14 changes: 12 additions & 2 deletions daisy/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import multiprocessing
import os
import queue
import dill

logger = logging.getLogger(__name__)

Expand All @@ -27,6 +28,7 @@ class Worker:
"""

__next_id = multiprocessing.Value("L")
_spawn_function = None

@staticmethod
def get_next_id():
Expand All @@ -49,14 +51,22 @@ def __init__(self, spawn_function, context=None, error_queue=None):

self.start()

@property
def spawn_function(self):
return dill.loads(self._spawn_function)

@spawn_function.setter
def spawn_function(self, value):
self._spawn_function = dill.dumps(value)

def start(self):
"""Start this worker. Note that workers are automatically started when
created. Use this function to re-start a stopped worker."""

if self.process is not None:
return

self.process = multiprocessing.Process(target=lambda: self.__spawn_wrapper())
self.process = multiprocessing.Process(target=self._spawn_wrapper)
self.process.start()

def stop(self):
Expand All @@ -74,7 +84,7 @@ def stop(self):
logger.debug("%s terminated", self)
self.process = None

def __spawn_wrapper(self):
def _spawn_wrapper(self):
"""Thin wrapper around the user-specified spawn function to set
environment variables, redirect output, and to capture exceptions."""

Expand Down
10 changes: 5 additions & 5 deletions daisy/worker_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ def set_num_workers(self, num_workers):
logger.debug("current number of workers: %d", len(self.workers))

if diff > 0:
self.__start_workers(diff)
self._start_workers(diff)
elif diff < 0:
self.__stop_workers(-diff)
self._stop_workers(-diff)

def inc_num_workers(self, num_workers):
self.__start_workers(num_workers)
self._start_workers(num_workers)

def stop(self, worker_id=None):
"""Stop all current workers in this pool (``worker_id == None``) or a
Expand Down Expand Up @@ -104,7 +104,7 @@ def check_for_errors(self):
except queue.Empty:
pass

def __start_workers(self, n):
def _start_workers(self, n):

logger.debug("starting %d new workers", n)
new_workers = [
Expand All @@ -113,7 +113,7 @@ def __start_workers(self, n):
]
self.workers.update({worker.worker_id: worker for worker in new_workers})

def __stop_workers(self, n):
def _stop_workers(self, n):

logger.debug("stopping %d workers", n)

Expand Down
2 changes: 1 addition & 1 deletion examples/batch_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def _prepare_task(
)

if check_fn is None:
check_fn = lambda b: self._default_check_fn(b)
check_fn = self._default_check_fn

if self.overwrite:
print("Dropping table %s" % self.db_id)
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"tqdm",
"funlib.math",
"funlib.geometry",
"dill",
]

[project.optional-dependencies]
Expand All @@ -46,5 +47,6 @@ module = [
"funlib.*",
"tqdm.*",
"pkg_resources.*",
"dill",
]
ignore_missing_imports = true
83 changes: 40 additions & 43 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,51 +5,48 @@
from daisy.tcp import TCPServer


class TestClient(unittest.TestCase):
def run_test_server(block, conn):
server = TCPServer()
conn.send(server.address)

def run_test_server(self, block, conn):
server = TCPServer()
conn.send(server.address)
# handle first acquire_block message
message = None
for i in range(10):
message = server.get_message(timeout=1)
if message:
break
if not message:
raise Exception("SERVER COULDN'T GET MESSAGE")
try:
assert isinstance(message, AcquireBlock)
message.stream.send_message(SendBlock(block))
except Exception as e:
message.stream.send_message(ExceptionMessage(e))

# handle first acquire_block message
message = None
for i in range(10):
message = server.get_message(timeout=1)
if message:
break
if not message:
raise Exception("SERVER COULDN'T GET MESSAGE")
try:
self.assertTrue(isinstance(message, AcquireBlock))
message.stream.send_message(SendBlock(block))
except Exception as e:
message.stream.send_message(ExceptionMessage(e))
# handle return_block message
message = server.get_message(timeout=1)
try:
assert isinstance(message, ReleaseBlock)
assert message.block.status == daisy.BlockStatus.SUCCESS
except Exception as e:
message.stream.send_message(ExceptionMessage(e))
conn.send(1)
conn.close()

# handle return_block message
message = server.get_message(timeout=1)
try:
self.assertTrue(isinstance(message, ReleaseBlock))
self.assertTrue(message.block.status == daisy.BlockStatus.SUCCESS)
except Exception as e:
message.stream.send_message(ExceptionMessage(e))
conn.send(1)
conn.close()

def test_basic(self):
roi = daisy.Roi((0, 0, 0), (10, 10, 10))
task_id = 1
block = daisy.Block(roi, roi, roi, block_id=1, task_id=task_id)
parent_conn, child_conn = mp.Pipe()
server_process = mp.Process(
target=self.run_test_server, args=(block, child_conn)
)
server_process.start()
host, port = parent_conn.recv()
context = daisy.Context(hostname=host, port=port, task_id=task_id, worker_id=1)
client = daisy.Client(context=context)
with client.acquire_block() as block:
block.status = daisy.BlockStatus.SUCCESS
def test_basic():
roi = daisy.Roi((0, 0, 0), (10, 10, 10))
task_id = 1
block = daisy.Block(roi, roi, roi, block_id=1, task_id=task_id)
parent_conn, child_conn = mp.Pipe()
server_process = mp.Process(target=run_test_server, args=(block, child_conn))
server_process.start()
host, port = parent_conn.recv()
context = daisy.Context(hostname=host, port=port, task_id=task_id, worker_id=1)
client = daisy.Client(context=context)
with client.acquire_block() as block:
block.status = daisy.BlockStatus.SUCCESS

success = parent_conn.recv()
server_process.join()
self.assertTrue(success)
success = parent_conn.recv()
server_process.join()
assert success
37 changes: 18 additions & 19 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,25 @@
logging.basicConfig(level=logging.DEBUG)


class TestServer(unittest.TestCase):
def process_block(block):
print("Processing block %s" % block)

def test_basic(self):

task = daisy.Task(
"test_server_task",
total_roi=daisy.Roi((0,), (100,)),
read_roi=daisy.Roi((0,), (10,)),
write_roi=daisy.Roi((1,), (8,)),
process_function=lambda b: self.process_block(b),
check_function=None,
read_write_conflict=True,
fit="valid",
num_workers=1,
max_retries=2,
timeout=None,
)
def test_basic():

server = daisy.Server()
server.run_blockwise([task])
task = daisy.Task(
"test_server_task",
total_roi=daisy.Roi((0,), (100,)),
read_roi=daisy.Roi((0,), (10,)),
write_roi=daisy.Roi((1,), (8,)),
process_function=process_block,
check_function=None,
read_write_conflict=True,
fit="valid",
num_workers=1,
max_retries=2,
timeout=None,
)

def process_block(self, block):
print("Processing block %s" % block)
server = daisy.Server()
server.run_blockwise([task])

0 comments on commit ad974c1

Please sign in to comment.