From 32a27156bd0fcc4b7c69f3509f453b833a683919 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Fri, 23 Feb 2024 20:37:39 +0000 Subject: [PATCH] =?UTF-8?q?style:=20=F0=9F=8E=A8=20Black=20formatted.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- daisy/__init__.py | 30 ++-- daisy/block.py | 13 +- daisy/block_bookkeeper.py | 30 ++-- daisy/blocks.py | 73 +++------- daisy/cl_monitor.py | 49 ++++--- daisy/client.py | 35 ++--- daisy/context.py | 18 ++- daisy/convenience.py | 18 +-- daisy/coordinate.py | 2 +- daisy/dependency_graph.py | 81 ++++------- daisy/freezable.py | 3 +- daisy/logging.py | 30 ++-- daisy/messages/message.py | 4 +- daisy/ready_surface.py | 4 +- daisy/roi.py | 2 +- daisy/scheduler.py | 22 ++- daisy/server.py | 71 ++++------ daisy/task.py | 5 +- daisy/task_worker_pools.py | 17 +-- daisy/tcp/internal_messages.py | 15 +- daisy/tcp/io_looper.py | 8 +- daisy/tcp/tcp_client.py | 30 ++-- daisy/tcp/tcp_message.py | 5 +- daisy/tcp/tcp_server.py | 25 ++-- daisy/tcp/tcp_stream.py | 18 +-- daisy/worker.py | 31 ++-- daisy/worker_pool.py | 21 ++- docs/conf.py | 61 ++++---- examples/basic_workflow.py | 63 ++++++--- examples/batch_task.py | 218 +++++++++++++++++------------ examples/chaining_example.py | 54 ++++--- examples/gaussian_smoothing1.py | 116 ++++++++------- examples/gaussian_smoothing2.py | 110 +++++++++------ examples/hdf_to_zarr.py | 98 +++++++------ examples/minimal_example/server.py | 17 +-- examples/visualize.py | 105 +++++++------- tests/test_client.py | 18 +-- tests/test_scheduler.py | 5 +- tests/test_server.py | 24 ++-- tests/tmpdir_test.py | 21 +-- 40 files changed, 793 insertions(+), 777 deletions(-) diff --git a/daisy/__init__.py b/daisy/__init__.py index 4aa01b07..d650b344 100644 --- a/daisy/__init__.py +++ b/daisy/__init__.py @@ -1,16 +1,16 @@ from __future__ import absolute_import -from .block import Block, BlockStatus # noqa -from .blocks import expand_roi_to_grid # noqa -from .blocks import expand_write_roi_to_grid # noqa -from .client import Client # noqa -from .context import Context # noqa -from .convenience import run_blockwise # noqa -from .coordinate import Coordinate # noqa -from .dependency_graph import DependencyGraph, BlockwiseDependencyGraph # noqa -from .logging import get_worker_log_basename, redirect_stdouterr # noqa -from .roi import Roi # noqa -from .scheduler import Scheduler # noqa -from .server import Server # noqa -from .task import Task # noqa -from .worker import Worker # noqa -from .worker_pool import WorkerPool # noqa +from .block import Block, BlockStatus # noqa +from .blocks import expand_roi_to_grid # noqa +from .blocks import expand_write_roi_to_grid # noqa +from .client import Client # noqa +from .context import Context # noqa +from .convenience import run_blockwise # noqa +from .coordinate import Coordinate # noqa +from .dependency_graph import DependencyGraph, BlockwiseDependencyGraph # noqa +from .logging import get_worker_log_basename, redirect_stdouterr # noqa +from .roi import Roi # noqa +from .scheduler import Scheduler # noqa +from .server import Server # noqa +from .task import Task # noqa +from .worker import Worker # noqa +from .worker_pool import WorkerPool # noqa diff --git a/daisy/block.py b/daisy/block.py index 48a3e47a..8e4703d5 100644 --- a/daisy/block.py +++ b/daisy/block.py @@ -66,13 +66,8 @@ class Block(Freezable): The id of the Task that this block belongs to. Defaults to None. """ - def __init__( - self, - total_roi, - read_roi, - write_roi, - block_id=None, - task_id=None): + + def __init__(self, total_roi, read_roi, write_roi, block_id=None, task_id=None): self.read_roi = read_roi self.write_roi = write_roi @@ -92,9 +87,7 @@ def copy(self): return copy.deepcopy(self) def __compute_block_id(self, total_roi, write_roi, shift=None): - block_index = ( - write_roi.offset - total_roi.offset - ) / write_roi.shape + block_index = (write_roi.offset - total_roi.offset) / write_roi.shape # block_id will be the cantor number for this block index block_id = int(cantor_number(block_index)) diff --git a/daisy/block_bookkeeper.py b/daisy/block_bookkeeper.py index 173e8565..fcd30b94 100644 --- a/daisy/block_bookkeeper.py +++ b/daisy/block_bookkeeper.py @@ -16,37 +16,39 @@ def __init__(self, processing_timeout=None): self.sent_blocks = {} def notify_block_sent(self, block, stream): - '''Notify the bookkeeper that a block has been sent to a client (i.e., - a stream to the client).''' + """Notify the bookkeeper that a block has been sent to a client (i.e., + a stream to the client).""" - assert block.block_id not in self.sent_blocks, \ - f"Attempted to send block {block}, although it is already being " \ + assert block.block_id not in self.sent_blocks, ( + f"Attempted to send block {block}, although it is already being " f"processed by {self.sent_blocks[block.block_id].stream}" + ) self.sent_blocks[block.block_id] = BlockLog(block, stream) def notify_block_returned(self, block, stream): - '''Notify the bookkeeper that a block was returned.''' + """Notify the bookkeeper that a block was returned.""" - assert block.block_id in self.sent_blocks, \ - f"Block {block} was returned by {stream}, but is not in list " \ + assert block.block_id in self.sent_blocks, ( + f"Block {block} was returned by {stream}, but is not in list " "of sent blocks" + ) log = self.sent_blocks[block.block_id] block.started_processing = log.time_sent block.stopped_processing = time.time() - assert stream == log.stream, \ - f"Block {block} was returned by {stream}, but was sent to" \ - f"{log.stream}" + assert stream == log.stream, ( + f"Block {block} was returned by {stream}, but was sent to" f"{log.stream}" + ) del self.sent_blocks[block.block_id] def is_valid_return(self, block, stream): - '''Check whether the block from the given client (i.e., stream) is + """Check whether the block from the given client (i.e., stream) is expected to be returned from this client. This is to avoid double returning blocks that have already been returned as lost blocks, but - still come back from the client due to race conditions.''' + still come back from the client due to race conditions.""" # block was never sent or already returned if block.block_id not in self.sent_blocks: @@ -59,10 +61,10 @@ def is_valid_return(self, block, stream): return True def get_lost_blocks(self): - '''Return a list of blocks that were sent and are lost, either because + """Return a list of blocks that were sent and are lost, either because the stream to the client closed or the processing timed out. Those blocks are removed from the sent-list with the call of this - function.''' + function.""" lost_block_ids = [] for block_id, log in self.sent_blocks.items(): diff --git a/daisy/blocks.py b/daisy/blocks.py index b3dad3aa..a35574a1 100644 --- a/daisy/blocks.py +++ b/daisy/blocks.py @@ -138,16 +138,11 @@ def create_dependency_graph(self): ] # TODO: can we do this part lazily? This might be a lot of # Coordinates - block_offsets = [ - Coordinate(o) - for o in product(*block_dim_offsets)] + block_offsets = [Coordinate(o) for o in product(*block_dim_offsets)] # convert to global coordinates block_offsets = [ - o + ( - self.total_roi.get_begin() - - self.block_read_roi.get_begin() - ) + o + (self.total_roi.get_begin() - self.block_read_roi.get_begin()) for o in block_offsets ] @@ -173,12 +168,8 @@ def compute_level_stride(self): self.block_write_roi ), "Read ROI must contain write ROI." - context_ul = ( - self.block_write_roi.get_begin() - - self.block_read_roi.get_begin()) - context_lr = ( - self.block_read_roi.get_end() - - self.block_write_roi.get_end()) + context_ul = self.block_write_roi.get_begin() - self.block_read_roi.get_begin() + context_lr = self.block_read_roi.get_end() - self.block_write_roi.get_end() max_context = Coordinate( (max(ul, lr) for ul, lr in zip(context_ul, context_lr)) @@ -195,14 +186,14 @@ def compute_level_stride(self): # to avoid overlapping write ROIs, increase the stride to the next # multiple of write shape write_shape = self.block_write_roi.get_shape() - level_stride = Coordinate(( - ((level - 1) // w + 1) * w - for level, w in zip(min_level_stride, write_shape) - )) + level_stride = Coordinate( + ( + ((level - 1) // w + 1) * w + for level, w in zip(min_level_stride, write_shape) + ) + ) - logger.debug( - "final level stride (multiples of write size) is %s", - level_stride) + logger.debug("final level stride (multiples of write size) is %s", level_stride) return level_stride @@ -219,26 +210,16 @@ def compute_level_offsets(self): ) dim_offsets = [ - range(0, e, step) - for e, step in zip(self.level_stride, write_stride) + range(0, e, step) for e, step in zip(self.level_stride, write_stride) ] - level_offsets = list( - reversed([ - Coordinate(o) - for o in product(*dim_offsets) - ]) - ) + level_offsets = list(reversed([Coordinate(o) for o in product(*dim_offsets)])) logger.debug("level offsets: %s", level_offsets) return level_offsets - def get_conflict_offsets( - self, - level_offset, - prev_level_offset, - level_stride): + def get_conflict_offsets(self, level_offset, prev_level_offset, level_stride): """Get the offsets to all previous level blocks that are in conflict with the current level blocks.""" @@ -250,13 +231,8 @@ def get_conflict_offsets( for op, ls in zip(offset_to_prev, level_stride) ] - conflict_offsets = [ - Coordinate(o) - for o in product(*conflict_dim_offsets) - ] - logger.debug( - "conflict offsets to previous level: %s", - conflict_offsets) + conflict_offsets = [Coordinate(o) for o in product(*conflict_dim_offsets)] + logger.debug("conflict offsets to previous level: %s", conflict_offsets) return conflict_offsets @@ -264,8 +240,7 @@ def enumerate_dependencies(self, conflict_offsets, block_offsets): inclusion_criteria = { "valid": lambda b: self.total_roi.contains(b.read_roi), - "overhang": lambda b: self.total_roi.contains( - b.write_roi.get_begin()), + "overhang": lambda b: self.total_roi.contains(b.write_roi.get_begin()), "shrink": lambda b: self.shrink_possible(b), }[self.fit] @@ -388,10 +363,7 @@ def get_subgraph_blocks(self, sub_roi): def expand_roi_to_grid(sub_roi, total_roi, read_roi, write_roi): """Expands given roi so that its write region is aligned to write_roi""" - offset = ( - write_roi.get_begin() + - total_roi.get_begin() - - read_roi.get_begin()) + offset = write_roi.get_begin() + total_roi.get_begin() - read_roi.get_begin() begin = sub_roi.get_begin() - offset end = sub_roi.get_end() - offset @@ -405,10 +377,7 @@ def expand_roi_to_grid(sub_roi, total_roi, read_roi, write_roi): def expand_request_roi_to_grid(req_roi, total_roi, read_roi, write_roi): """Expands given roi so that its write region is aligned to write_roi""" - offset = ( - write_roi.get_begin() + - total_roi.get_begin() - - read_roi.get_begin()) + offset = write_roi.get_begin() + total_roi.get_begin() - read_roi.get_begin() begin = req_roi.get_begin() - offset end = req_roi.get_end() - offset @@ -430,7 +399,5 @@ def expand_write_roi_to_grid(roi, write_roi): -(-roi.get_end() // write_roi.get_shape()), ) # `ceildiv` - roi = Roi( - roi[0] * write_roi.get_shape(), - (roi[1] - roi[0]) * write_roi.get_shape()) + roi = Roi(roi[0] * write_roi.get_shape(), (roi[1] - roi[0]) * write_roi.get_shape()) return roi diff --git a/daisy/cl_monitor.py b/daisy/cl_monitor.py index dddffc72..99cfddba 100644 --- a/daisy/cl_monitor.py +++ b/daisy/cl_monitor.py @@ -6,12 +6,12 @@ class TqdmLoggingHandler: - '''A logging handler that uses ``tqdm.tqdm.write`` in ``emit()``, such that + """A logging handler that uses ``tqdm.tqdm.write`` in ``emit()``, such that logging doesn't interfere with tqdm's progress bar. Heavily inspired by the fantastic https://github.com/EpicWink/tqdm-logging-wrapper/ - ''' + """ def __init__(self, handler): self.handler = handler @@ -42,8 +42,10 @@ def __init__(self, block, exception, worker_id): self.worker_id = worker_id def __repr__(self): - return f"block {self.block.block_id[1]} in worker " \ + return ( + f"block {self.block.block_id[1]} in worker " f"{self.worker_id} with exception {repr(self.exception)}" + ) class CLMonitor(ServerObserver): @@ -56,13 +58,13 @@ def __init__(self, server): self._wrap_logging_handlers() def _wrap_logging_handlers(self): - '''This adds a TqdmLoggingHandler around each logging handler that has + """This adds a TqdmLoggingHandler around each logging handler that has a TTY stream attached to it, so that logging doesn't interfere with the progress bar. Heavily inspired by the fantastic https://github.com/EpicWink/tqdm-logging-wrapper/ - ''' + """ logger = logging.root for i in range(len(logger.handlers)): @@ -87,7 +89,8 @@ def on_block_failure(self, block, exception, context): task_id = block.block_id[0] self.summaries[task_id].block_failures.append( - BlockFailure(block, exception, context['worker_id'])) + BlockFailure(block, exception, context["worker_id"]) + ) def on_task_start(self, task_id, task_state): @@ -130,8 +133,10 @@ def on_server_exit(self): print() state = summary.state print(f" num blocks : {state.total_block_count}") - print(f" completed ✔: {state.completed_count} " - f"(skipped {state.skipped_count})") + print( + f" completed ✔: {state.completed_count} " + f"(skipped {state.skipped_count})" + ) print(f" failed ✗: {state.failed_count}") print(f" orphaned ∅: {state.orphaned_count}") print() @@ -151,8 +156,8 @@ def on_server_exit(self): print() for block_failure in summary.block_failures[:10]: log_basename = daisy_logging.get_worker_log_basename( - block_failure.worker_id, - block_failure.block.block_id[0]) + block_failure.worker_id, block_failure.block.block_id[0] + ) print(f" {log_basename}.err / .out") if num_block_failures > 10: print(" ...") @@ -165,18 +170,18 @@ def _update_state(self, task_id, task_state): if task_id not in self.progresses: total = task_state.total_block_count self.progresses[task_id] = tqdm_auto( - total=total, - desc=task_id + " ▶", - unit='blocks', - leave=True) - - self.progresses[task_id].set_postfix({ - '⧗': task_state.pending_count, - '▶': task_state.processing_count, - '✔': task_state.completed_count, - '✗': task_state.failed_count, - '∅': task_state.orphaned_count - }) + total=total, desc=task_id + " ▶", unit="blocks", leave=True + ) + + self.progresses[task_id].set_postfix( + { + "⧗": task_state.pending_count, + "▶": task_state.processing_count, + "✔": task_state.completed_count, + "✗": task_state.failed_count, + "∅": task_state.orphaned_count, + } + ) completed = task_state.completed_count delta = completed - self.progresses[task_id].n diff --git a/daisy/client.py b/daisy/client.py index 55155939..bfc9f2ca 100644 --- a/daisy/client.py +++ b/daisy/client.py @@ -6,7 +6,8 @@ ReleaseBlock, RequestShutdown, SendBlock, - UnexpectedMessage) + UnexpectedMessage, +) from contextlib import contextmanager from daisy.tcp import TCPClient, StreamClosedError import logging @@ -14,8 +15,8 @@ logger = logging.getLogger(__name__) -class Client(): - '''Client code that runs on a remote worker providing task management +class Client: + """Client code that runs on a remote worker providing task management API for user code. It communicates with the scheduler through TCP/IP. Scheduler IP address, port, and other configurations are typically @@ -35,12 +36,10 @@ def main(): break blockwise_process(block) block.state = BlockStatus.SUCCESS # (or FAILED) - ''' + """ - def __init__( - self, - context=None): - '''Initialize a client and connect to the server. + def __init__(self, context=None): + """Initialize a client and connect to the server. Args: @@ -50,24 +49,24 @@ def __init__( given, the context will be read from environment variable ``DAISY_CONTEXT``. - ''' + """ logger.debug("Client init") self.context = context if self.context is None: self.context = Context.from_env() logger.debug("Client context: %s", self.context) - self.host = self.context['hostname'] - self.port = int(self.context['port']) - self.worker_id = int(self.context['worker_id']) - self.task_id = self.context['task_id'] + self.host = self.context["hostname"] + self.port = int(self.context["port"]) + self.worker_id = int(self.context["worker_id"]) + self.task_id = self.context["task_id"] # Make TCP Connection self.tcp_client = TCPClient(self.host, self.port) @contextmanager def acquire_block(self): - '''API for client to get a new block.''' + """API for client to get a new block.""" self.tcp_client.send_message(AcquireBlock(self.task_id)) message = None try: @@ -89,12 +88,8 @@ def acquire_block(self): block.status = BlockStatus.SUCCESS except Exception as e: block.status = BlockStatus.FAILED - self.tcp_client.send_message( - BlockFailed(e, block, self.context)) - logger.exception( - "Block %s failed in worker %d", - block, - self.worker_id) + self.tcp_client.send_message(BlockFailed(e, block, self.context)) + logger.exception("Block %s failed in worker %d", block, self.worker_id) finally: # if we somehow got here without setting the block status to # "SUCCESS" (e.g., through KeyboardInterrupt), we assume the diff --git a/daisy/context.py b/daisy/context.py index 35f20c02..82a52a1f 100644 --- a/daisy/context.py +++ b/daisy/context.py @@ -5,9 +5,9 @@ logger = logging.getLogger(__name__) -class Context(): +class Context: - ENV_VARIABLE = 'DAISY_CONTEXT' + ENV_VARIABLE = "DAISY_CONTEXT" def __init__(self, **kwargs): @@ -19,16 +19,16 @@ def copy(self): def to_env(self): - return ':'.join('%s=%s' % (k, v) for k, v in self.__dict.items()) + return ":".join("%s=%s" % (k, v) for k, v in self.__dict.items()) def __setitem__(self, k, v): k = str(k) v = str(v) - if '=' in k or ':' in k: + if "=" in k or ":" in k: raise RuntimeError("Context variables must not contain = or :.") - if '=' in v or ':' in v: + if "=" in v or ":" in v: raise RuntimeError("Context values must not contain = or :.") self.__dict[k] = v @@ -50,19 +50,17 @@ def from_env(): try: - tokens = os.environ[Context.ENV_VARIABLE].split(':') + tokens = os.environ[Context.ENV_VARIABLE].split(":") except KeyError: - logger.error( - "%s environment variable not found!", - Context.ENV_VARIABLE) + logger.error("%s environment variable not found!", Context.ENV_VARIABLE) raise context = Context() for token in tokens: - k, v = token.split('=') + k, v = token.split("=") context[k] = v return context diff --git a/daisy/convenience.py b/daisy/convenience.py index 767fb057..344e6c6c 100644 --- a/daisy/convenience.py +++ b/daisy/convenience.py @@ -6,17 +6,17 @@ def run_blockwise(tasks): - '''Schedule and run the given tasks. + """Schedule and run the given tasks. - Args: - list_of_tasks: - The tasks to schedule over. + Args: + list_of_tasks: + The tasks to schedule over. - Return: - bool: - `True` if all blocks in the given `tasks` were successfully - run, else `False` - ''' + Return: + bool: + `True` if all blocks in the given `tasks` were successfully + run, else `False` + """ task_ids = set() all_tasks = [] while len(tasks) > 0: diff --git a/daisy/coordinate.py b/daisy/coordinate.py index a05770bf..c37998ea 100644 --- a/daisy/coordinate.py +++ b/daisy/coordinate.py @@ -1 +1 @@ -from funlib.geometry import Coordinate # noqa \ No newline at end of file +from funlib.geometry import Coordinate # noqa diff --git a/daisy/dependency_graph.py b/daisy/dependency_graph.py index b12475a2..29fb61b3 100644 --- a/daisy/dependency_graph.py +++ b/daisy/dependency_graph.py @@ -160,9 +160,7 @@ def inclusion_criteria(self): # TODO: Can't we remove this entirely by pre computing the write_roi inclusion_criteria = { "valid": lambda b: self.total_write_roi.contains(b.write_roi), - "overhang": lambda b: self.total_write_roi.contains( - b.write_roi.begin - ), + "overhang": lambda b: self.total_write_roi.contains(b.write_roi.begin), "shrink": lambda b: self.shrink_possible(b), }[self.fit] return inclusion_criteria @@ -202,7 +200,8 @@ def _num_level_blocks(self, level): level_offset, self._level_stride, num_blocks, - axis_blocks) + axis_blocks, + ) return num_blocks @@ -230,9 +229,7 @@ def root_gen(self): def _block_offset(self, block): # The block offset is the offset of the read roi relative to total roi - block_offset = ( - block.read_roi.offset - - self.total_read_roi.offset) + block_offset = block.read_roi.offset - self.total_read_roi.offset return block_offset def _level(self, block): @@ -332,12 +329,8 @@ def compute_level_stride(self) -> Coordinate: self.block_write_roi ), "Read ROI must contain write ROI." - context_ul = ( - self.block_write_roi.begin - - self.block_read_roi.begin) - context_lr = ( - self.block_read_roi.end - - self.block_write_roi.end) + context_ul = self.block_write_roi.begin - self.block_read_roi.begin + context_lr = self.block_read_roi.end - self.block_write_roi.end max_context = Coordinate( (max(ul, lr) for ul, lr in zip(context_ul, context_lr)) @@ -354,10 +347,12 @@ def compute_level_stride(self) -> Coordinate: # to avoid overlapping write ROIs, increase the stride to the next # multiple of write shape write_shape = self.block_write_roi.shape - level_stride = Coordinate(( - ((level - 1) // w + 1) * w - for level, w in zip(min_level_stride, write_shape) - )) + level_stride = Coordinate( + ( + ((level - 1) // w + 1) * w + for level, w in zip(min_level_stride, write_shape) + ) + ) # Handle case where min_level_stride > total_write_roi. # This case leads to levels with no blocks in them. This makes @@ -373,11 +368,10 @@ def compute_level_stride(self) -> Coordinate: ) % self.block_write_roi.shape level_stride = Coordinate( - (min(a, b) for a, b in zip(level_stride, write_roi_shape))) + (min(a, b) for a, b in zip(level_stride, write_roi_shape)) + ) - logger.debug( - "final level stride (multiples of write size) is %s", - level_stride) + logger.debug("final level stride (multiples of write size) is %s", level_stride) return level_stride @@ -395,15 +389,10 @@ def compute_level_offsets(self) -> List[Coordinate]: ) dim_offsets = [ - range(0, e, step) - for e, step in zip(self._level_stride, write_stride) + range(0, e, step) for e, step in zip(self._level_stride, write_stride) ] - level_offsets = list( - reversed([ - Coordinate(o) - for o in product(*dim_offsets)]) - ) + level_offsets = list(reversed([Coordinate(o) for o in product(*dim_offsets)])) logger.debug("level offsets: %s", level_offsets) @@ -451,10 +440,7 @@ def _compute_level_block_offsets(self, level): block_offset = Coordinate(offset) # convert to global coordinates - block_offset += ( - self.total_read_roi.begin - - self.block_read_roi.begin - ) + block_offset += self.total_read_roi.begin - self.block_read_roi.begin yield block_offset def compute_level_block_offsets(self) -> List[List[Coordinate]]: @@ -466,17 +452,11 @@ def compute_level_block_offsets(self) -> List[List[Coordinate]]: for level in range(self.num_levels): - level_block_offsets.append( - list( - self._compute_level_block_offsets(level))) + level_block_offsets.append(list(self._compute_level_block_offsets(level))) return level_block_offsets - def get_conflict_offsets( - self, - level_offset, - prev_level_offset, - level_stride): + def get_conflict_offsets(self, level_offset, prev_level_offset, level_stride): """Get the offsets to all previous level blocks that are in conflict with the current level blocks.""" @@ -495,13 +475,8 @@ def get_offsets(op, ls): get_offsets(op, ls) for op, ls in zip(offset_to_prev, level_stride) ] - conflict_offsets = [ - Coordinate(o) - for o in product(*conflict_dim_offsets) - ] - logger.debug( - "conflict offsets to previous level: %s", - conflict_offsets) + conflict_offsets = [Coordinate(o) for o in product(*conflict_dim_offsets)] + logger.debug("conflict offsets to previous level: %s", conflict_offsets) return conflict_offsets @@ -536,7 +511,8 @@ def get_subgraph_blocks(self, sub_roi, read_roi=False): # only need to check if a blocks read_roi overlaps with sub_roi. # This is the same behavior as when we want write_roi overlap sub_roi = sub_roi.grow( - self.read_write_context[0], self.read_write_context[1]) + self.read_write_context[0], self.read_write_context[1] + ) # TODO: handle unsatisfiable sub_rois # i.e. sub_roi is outside of *total_write_roi @@ -690,8 +666,7 @@ def __enumerate_all_dependencies(self): "Block dependency %s is not found for task %s." % (upstream_block.block_id, task_id) ) - self._downstream[upstream_block.block_id].add( - block.block_id) + self._downstream[upstream_block.block_id].add(block.block_id) self._upstream[block.block_id].add(upstream_block.block_id) # enumerate all of the upstream / downstream dependencies @@ -713,7 +688,5 @@ def __enumerate_all_dependencies(self): "Block dependency %s is not found for task %s." % (upstream_block.block_id, task_id) ) - self._downstream[upstream_block.block_id].add( - block.block_id) - self._upstream[block.block_id].add( - upstream_block.block_id) + self._downstream[upstream_block.block_id].add(block.block_id) + self._upstream[block.block_id].add(upstream_block.block_id) diff --git a/daisy/freezable.py b/daisy/freezable.py index 93572868..5961a9e9 100644 --- a/daisy/freezable.py +++ b/daisy/freezable.py @@ -4,8 +4,7 @@ class Freezable(object): def __setattr__(self, key, value): if self.__isfrozen and not hasattr(self, key): - raise TypeError( - "%r is frozen, you can't add attributes to it" % self) + raise TypeError("%r is frozen, you can't add attributes to it" % self) object.__setattr__(self, key, value) def freeze(self): diff --git a/daisy/logging.py b/daisy/logging.py index 8a917b58..504c6792 100644 --- a/daisy/logging.py +++ b/daisy/logging.py @@ -4,16 +4,16 @@ # default log dir -LOG_BASEDIR = Path('./daisy_logs') +LOG_BASEDIR = Path("./daisy_logs") def set_log_basedir(path): - '''Set the base directory for logging (indivudal worker logs and detailed + """Set the base directory for logging (indivudal worker logs and detailed task summaries). If set to ``None``, all logging will be shown on the command line (which can get very messy). Default is ``./daisy_logs``. - ''' + """ global LOG_BASEDIR @@ -24,7 +24,7 @@ def set_log_basedir(path): def get_worker_log_basename(worker_id, task_id=None): - '''Get the basename of log files for individual workers.''' + """Get the basename of log files for individual workers.""" if LOG_BASEDIR is None: return None @@ -32,17 +32,17 @@ def get_worker_log_basename(worker_id, task_id=None): basename = LOG_BASEDIR if task_id is not None: basename /= task_id - basename /= f'worker_{worker_id}' + basename /= f"worker_{worker_id}" return basename -def redirect_stdouterr(basename, mode='w'): - '''Redirect stdout and stderr of the current process to files:: +def redirect_stdouterr(basename, mode="w"): + """Redirect stdout and stderr of the current process to files:: - .out - .err - ''' + .out + .err + """ if basename is None: return @@ -53,14 +53,8 @@ def redirect_stdouterr(basename, mode='w'): logdir = basename.parent logdir.mkdir(parents=True, exist_ok=True) - sys.stdout = _file_reopen( - basename.with_suffix('.out'), - mode, - sys.__stdout__) - sys.stderr = _file_reopen( - basename.with_suffix('.err'), - mode, - sys.__stderr__) + sys.stdout = _file_reopen(basename.with_suffix(".out"), mode, sys.__stdout__) + sys.stderr = _file_reopen(basename.with_suffix(".err"), mode, sys.__stderr__) def _file_reopen(filename, mode, file_obj): diff --git a/daisy/messages/message.py b/daisy/messages/message.py index 454bf057..cbc362c8 100644 --- a/daisy/messages/message.py +++ b/daisy/messages/message.py @@ -9,13 +9,13 @@ class Message(TCPMessage): class ExceptionMessage(Message): - '''A message representing an exception. + """A message representing an exception. Args: exception (:class:`Exception`): The exception to wrap into this message. - ''' + """ def __init__(self, exception): self.exception = exception diff --git a/daisy/ready_surface.py b/daisy/ready_surface.py index 928376e0..6fe39289 100644 --- a/daisy/ready_surface.py +++ b/daisy/ready_surface.py @@ -53,9 +53,7 @@ def mark_success(self, node): # check if any downstream nodes need to be added to the boundary for down_node in self.downstream(node): if not self.__add_to_boundary(down_node): - if all( - up_node in self.surface - for up_node in self.upstream(down_node)): + if all(up_node in self.surface for up_node in self.upstream(down_node)): new_ready_nodes.append(down_node) # check if any of the upstream nodes can be removed from surface diff --git a/daisy/roi.py b/daisy/roi.py index 85054d70..c4b04dbc 100644 --- a/daisy/roi.py +++ b/daisy/roi.py @@ -1 +1 @@ -from funlib.geometry import Roi # noqa \ No newline at end of file +from funlib.geometry import Roi # noqa diff --git a/daisy/scheduler.py b/daisy/scheduler.py index b2cf866a..d919a0db 100644 --- a/daisy/scheduler.py +++ b/daisy/scheduler.py @@ -90,19 +90,17 @@ def acquire_block(self, task_id): pre_check_ret = self.__precheck(block) if pre_check_ret: logger.debug( - "Skipping block (%s); already processed.", - block.block_id) + "Skipping block (%s); already processed.", block.block_id + ) block.status = BlockStatus.SUCCESS self.task_states[task_id].skipped_count += 1 # adding block so release_block() can remove it - self.task_queues[task_id].processing_blocks.add( - block.block_id) + self.task_queues[task_id].processing_blocks.add(block.block_id) self.release_block(block) continue else: self.task_states[task_id].started = True - self.task_queues[task_id].processing_blocks.add( - block.block_id) + self.task_queues[task_id].processing_blocks.add(block.block_id) return block else: @@ -158,15 +156,13 @@ def release_block(self, block): else: raise RuntimeError( f"Unexpected status for released block: {block.status} {block}" - ) + ) def __init_task(self, task): if task.task_id not in self.task_map: self.task_map[task.task_id] = task num_blocks = self.dependency_graph.num_blocks(task.task_id) - self.task_states[ - task.task_id - ].total_block_count = num_blocks + self.task_states[task.task_id].total_block_count = num_blocks for upstream_task in task.requires(): self.__init_task(upstream_task) @@ -179,8 +175,7 @@ def __queue_ready_block(self, block, index=None): self.task_states[block.task_id].ready_count += 1 def __remove_from_processing_blocks(self, block): - self.task_queues[block.task_id].processing_blocks.remove( - block.block_id) + self.task_queues[block.task_id].processing_blocks.remove(block.block_id) self.task_states[block.task_id].processing_count -= 1 def __update_ready_queue(self, ready_blocks): @@ -200,6 +195,5 @@ def __precheck(self, block): else: return False except Exception: - logger.exception( - f"pre_check() exception for block {block.block_id}") + logger.exception(f"pre_check() exception for block {block.block_id}") return False diff --git a/daisy/server.py b/daisy/server.py index 8d695cba..11540ec0 100644 --- a/daisy/server.py +++ b/daisy/server.py @@ -8,7 +8,8 @@ ReleaseBlock, SendBlock, RequestShutdown, - UnexpectedMessage) + UnexpectedMessage, +) from .scheduler import Scheduler from .server_observer import ServerObservee from .task_worker_pools import TaskWorkerPools @@ -35,10 +36,7 @@ def __init__(self, stop_event=None): self.tcp_server = TCPServer() self.hostname, self.port = self.tcp_server.address - logger.debug( - "Started server listening at %s:%s", - self.hostname, - self.port) + logger.debug("Started server listening at %s:%s", self.hostname, self.port) def run_blockwise(self, tasks, scheduler=None): @@ -53,10 +51,7 @@ def run_blockwise(self, tasks, scheduler=None): self.finished_tasks = set() self.all_done = False - self.pending_requests = { - task.task_id: Queue() - for task in tasks - } + self.pending_requests = {task.task_id: Queue() for task in tasks} self._recruit_workers() @@ -145,29 +140,27 @@ def _handle_acquire_block(self, message): if task_state.pending_count == 0: logger.debug( - "No more pending blocks for task %s, terminating " - "client", message.task_id) + "No more pending blocks for task %s, terminating " "client", + message.task_id, + ) - self._send_client_message( - message.stream, - RequestShutdown()) + self._send_client_message(message.stream, RequestShutdown()) return # there are more blocks for this task, but none of them has its # dependencies fullfilled logger.debug( - "No currently ready blocks for task %s, delaying " - "request", message.task_id) + "No currently ready blocks for task %s, delaying " "request", + message.task_id, + ) self.pending_requests[message.task_id].put(message) else: try: logger.debug("Sending block %s to client", block) - self._send_client_message( - message.stream, - SendBlock(block)) + self._send_client_message(message.stream, SendBlock(block)) finally: self.block_bookkeeper.notify_block_sent(block, message.stream) @@ -179,8 +172,8 @@ def _handle_release_block(self, message): self._safe_release_block(message.block, message.stream) def _release_block(self, block): - '''Returns a block to the scheduler and checks whether all tasks are - completed.''' + """Returns a block to the scheduler and checks whether all tasks are + completed.""" self.scheduler.release_block(block) task_states = self.scheduler.task_states @@ -192,7 +185,7 @@ def _release_block(self, block): self._recruit_workers() def _check_all_tasks_completed(self): - '''Check if all tasks are completed and stop''' + """Check if all tasks are completed and stop""" self.all_done = True task_states = self.scheduler.task_states @@ -208,18 +201,15 @@ def _check_all_tasks_completed(self): self.all_done = False - logger.debug( - "Task %s has %d ready blocks", - task_id, - task_state.ready_count) + logger.debug("Task %s has %d ready blocks", task_id, task_state.ready_count) if self.all_done: logger.debug("All tasks finished") self.stop_event.set() def _safe_release_block(self, block, stream): - '''Releases a block, if the bookkeeper agrees that this is a valid - return from the given stream.''' + """Releases a block, if the bookkeeper agrees that this is a valid + return from the given stream.""" valid = self.block_bookkeeper.is_valid_return(block, stream) if valid: @@ -227,8 +217,8 @@ def _safe_release_block(self, block, stream): self.block_bookkeeper.notify_block_returned(block, stream) else: logger.debug( - "Attempted to return unexpected block %s from %s", - block, stream) + "Attempted to return unexpected block %s from %s", block, stream + ) def _handle_client_exception(self, message): @@ -237,17 +227,15 @@ def _handle_client_exception(self, message): logger.error( "Block %s failed in worker %s with %s", message.block, - message.context['worker_id'], - repr(message.exception)) + message.context["worker_id"], + repr(message.exception), + ) message.block.status = BlockStatus.FAILED self._safe_release_block(message.block, message.stream) - self.notify_block_failure( - message.block, - message.exception, - message.context) + self.notify_block_failure(message.block, message.exception, message.context) else: raise message.exception @@ -259,16 +247,17 @@ def _recruit_workers(self): for task_id in ready_tasks.keys(): if task_id not in self.started_tasks: - self.notify_task_start( - task_id, - self.scheduler.task_states[task_id]) + self.notify_task_start(task_id, self.scheduler.task_states[task_id]) self.started_tasks.add(task_id) # run the task's callback function - ready_tasks[task_id].init_callback_fn(Context( + ready_tasks[task_id].init_callback_fn( + Context( hostname=self.hostname, port=self.port, task_id=task_id, - worker_id=0)) + worker_id=0, + ) + ) self.worker_pools.recruit_workers(ready_tasks) diff --git a/daisy/task.py b/daisy/task.py index ce4ba857..f4f73758 100644 --- a/daisy/task.py +++ b/daisy/task.py @@ -3,7 +3,7 @@ class Task: - '''Definition of a ``daisy`` task that is to be run in a block-wise + """Definition of a ``daisy`` task that is to be run in a block-wise fashion. Args: @@ -123,7 +123,8 @@ class Task: Time in seconds to wait for a block to be returned from a worker. The worker is killed (and the block retried) if this time is exceeded. - ''' + """ + def __init__( self, task_id, diff --git a/daisy/task_worker_pools.py b/daisy/task_worker_pools.py index cf5596cc..e4eb7ed6 100644 --- a/daisy/task_worker_pools.py +++ b/daisy/task_worker_pools.py @@ -17,9 +17,9 @@ def __init__(self, tasks, server, max_block_failures=3): task.task_id: WorkerPool( task.spawn_worker_function, Context( - hostname=server.hostname, - port=server.port, - task_id=task.task_id)) + hostname=server.hostname, port=server.port, task_id=task.task_id + ), + ) for task in tasks } self.max_block_failures = max_block_failures @@ -32,7 +32,8 @@ def recruit_workers(self, tasks): logger.debug( "Setting number of workers for task %s to %d", task_id, - tasks[task_id].num_workers) + tasks[task_id].num_workers, + ) worker_pool.set_num_workers(tasks[task_id].num_workers) def stop(self): @@ -48,8 +49,8 @@ def check_worker_health(self): def on_block_failure(self, block, exception, context): - task_id = context['task_id'] - worker_id = int(context['worker_id']) + task_id = context["task_id"] + worker_id = int(context["worker_id"]) if task_id not in self.failure_counts: self.failure_counts[task_id] = {} @@ -62,8 +63,8 @@ def on_block_failure(self, block, exception, context): if self.failure_counts[task_id][worker_id] > self.max_block_failures: logger.error( - "Worker %s failed too many times, restarting this worker...", - context) + "Worker %s failed too many times, restarting this worker...", context + ) self.failure_counts[task_id][worker_id] = 0 worker_pool = self.worker_pools[task_id] diff --git a/daisy/tcp/internal_messages.py b/daisy/tcp/internal_messages.py index f067ff35..01dbddf7 100644 --- a/daisy/tcp/internal_messages.py +++ b/daisy/tcp/internal_messages.py @@ -2,22 +2,25 @@ class InternalTCPMessage(TCPMessage): - '''TCP messages used only between :class:`TCPServer` and :class:`TCPClient` + """TCP messages used only between :class:`TCPServer` and :class:`TCPClient` for internal communication. - ''' + """ + pass class NotifyClientDisconnect(InternalTCPMessage): - '''Message to be sent from a :class:`TCPClient` to :class:`TCPServer` to + """Message to be sent from a :class:`TCPClient` to :class:`TCPServer` to initiate a disconnect. - ''' + """ + pass class AckClientDisconnect(InternalTCPMessage): - '''Message to be sent from a :class:`TCPServer` to :class:`TCPClient` to + """Message to be sent from a :class:`TCPServer` to :class:`TCPClient` to confirm a disconnect, i.e., the server will no longer listen to messages received from this client and the associated stream can be closed. - ''' + """ + pass diff --git a/daisy/tcp/io_looper.py b/daisy/tcp/io_looper.py index fe303579..f899d94b 100644 --- a/daisy/tcp/io_looper.py +++ b/daisy/tcp/io_looper.py @@ -7,7 +7,7 @@ class IOLooper: - '''Base class for every class that needs access to tornado's IOLoop in a + """Base class for every class that needs access to tornado's IOLoop in a separate thread. Attributes: @@ -16,7 +16,7 @@ class IOLooper: The IO loop to be used in subclasses. Will run in a singleton thread per process. - ''' + """ threads = {} ioloops = {} @@ -38,8 +38,8 @@ def __init__(self): logger.debug("Starting io loop for process %d...", pid) IOLooper.threads[pid] = threading.Thread( - target=self.ioloop.start, - daemon=True) + target=self.ioloop.start, daemon=True + ) IOLooper.threads[pid].start() else: diff --git a/daisy/tcp/tcp_client.py b/daisy/tcp/tcp_client.py index 77ee3fcf..d58c45eb 100644 --- a/daisy/tcp/tcp_client.py +++ b/daisy/tcp/tcp_client.py @@ -1,9 +1,7 @@ from .exceptions import NotConnected from .io_looper import IOLooper from .tcp_stream import TCPStream -from .internal_messages import ( - AckClientDisconnect, - NotifyClientDisconnect) +from .internal_messages import AckClientDisconnect, NotifyClientDisconnect import logging import queue import time @@ -13,7 +11,7 @@ class TCPClient(IOLooper): - '''A TCP client to handle client-server communication through + """A TCP client to handle client-server communication through :class:`TCPMessage` objects. Args: @@ -22,7 +20,7 @@ class TCPClient(IOLooper): port (int): The hostname and port of the :class:`TCPServer` to connect to. - ''' + """ def __init__(self, host, port): @@ -45,21 +43,21 @@ def __del__(self): self.disconnect() def connect(self): - '''Connect to the server and start the message receive event loop.''' + """Connect to the server and start the message receive event loop.""" logger.debug("Connecting to server at %s:%d...", self.host, self.port) self.ioloop.add_callback(self._connect) while not self.connected(): self._check_for_errors() - time.sleep(.1) + time.sleep(0.1) logger.debug("...connected") self.ioloop.add_callback(self._receive) def disconnect(self): - '''Gracefully close the connection to the server.''' + """Gracefully close the connection to the server.""" if not self.connected(): logger.warn("Called disconnect() on disconnected client") @@ -69,23 +67,23 @@ def disconnect(self): self.stream.send_message(NotifyClientDisconnect()) while self.connected(): - time.sleep(.1) + time.sleep(0.1) logger.debug("Disconnected") def connected(self): - '''Check whether this client has a connection to the server.''' + """Check whether this client has a connection to the server.""" return self.stream is not None def send_message(self, message): - '''Send a message to the server. + """Send a message to the server. Args: message (:class:`TCPMessage`): Message to send over to the server. - ''' + """ self._check_for_errors() @@ -95,7 +93,7 @@ def send_message(self, message): self.stream.send_message(message) def get_message(self, timeout=None): - '''Get a message that was sent to this client. + """Get a message that was sent to this client. Args: @@ -104,7 +102,7 @@ def get_message(self, timeout=None): If set, wait up to `timeout` seconds for a message to arrive. If no message is available after the timeout, returns ``None``. If not set, wait until a message arrived. - ''' + """ self._check_for_errors() @@ -128,7 +126,7 @@ def _check_for_errors(self): return async def _connect(self): - '''Async method to connect to the TCPServer.''' + """Async method to connect to the TCPServer.""" try: stream = await self.client.connect(self.host, self.port) @@ -139,7 +137,7 @@ async def _connect(self): self.stream = TCPStream(stream, (self.host, self.port)) async def _receive(self): - '''Loop that receives messages from the server.''' + """Loop that receives messages from the server.""" logger.debug("Entering receive loop") diff --git a/daisy/tcp/tcp_message.py b/daisy/tcp/tcp_message.py index 818fc1c5..c3f84be6 100644 --- a/daisy/tcp/tcp_message.py +++ b/daisy/tcp/tcp_message.py @@ -1,5 +1,5 @@ class TCPMessage: - '''A message, to be sent between :class:`TCPServer` and :class:`TCPClient`. + """A message, to be sent between :class:`TCPServer` and :class:`TCPClient`. Args: @@ -12,7 +12,8 @@ class TCPMessage: The stream the message was received from. Will be set by :class:`TCPStream` and is ``None`` for messages that have not been sent. - ''' + """ + def __init__(self, payload=None): self.payload = payload diff --git a/daisy/tcp/tcp_server.py b/daisy/tcp/tcp_server.py index ed94046c..b20c757e 100644 --- a/daisy/tcp/tcp_server.py +++ b/daisy/tcp/tcp_server.py @@ -1,7 +1,5 @@ from .exceptions import NoFreePort -from .internal_messages import ( - AckClientDisconnect, - NotifyClientDisconnect) +from .internal_messages import AckClientDisconnect, NotifyClientDisconnect from .io_looper import IOLooper from .tcp_stream import TCPStream import logging @@ -13,14 +11,14 @@ class TCPServer(tornado.tcpserver.TCPServer, IOLooper): - '''A TCP server to handle client-server communication through + """A TCP server to handle client-server communication through :class:`Message` objects. Args: max_port_tries (int, optional): How many times to try to find an empty random port. - ''' + """ def __init__(self, max_port_tries=1000): @@ -43,13 +41,14 @@ def __init__(self, max_port_tries=1000): if i == self.max_port_tries - 1: raise NoFreePort( - "Could not find a free port after %d tries " % - self.max_port_tries) + "Could not find a free port after %d tries " + % self.max_port_tries + ) self.address = self._get_address() def get_message(self, timeout=None): - '''Get a message that was sent to this server. + """Get a message that was sent to this server. If the stream to any of the connected clients is closed, raises a :class:`StreamClosedError` for this client. Other TCP related @@ -62,7 +61,7 @@ def get_message(self, timeout=None): If set, wait up to `timeout` seconds for a message to arrive. If no message is available after the timeout, returns ``None``. If not set, wait until a message arrived. - ''' + """ self._check_for_errors() @@ -75,7 +74,7 @@ def get_message(self, timeout=None): return None def disconnect(self): - '''Close all open streams to clients.''' + """Close all open streams to clients.""" streams = list(self.client_streams) # avoid set change error for stream in streams: @@ -83,7 +82,7 @@ def disconnect(self): stream.close() async def handle_stream(self, stream, address): - ''' Overrides a function from tornado's TCPServer, and is called + """Overrides a function from tornado's TCPServer, and is called whenever there is a new IOStream from an incoming connection (not whenever there is new data in the IOStream). @@ -96,7 +95,7 @@ async def handle_stream(self, stream, address): address (tuple): host, port that new connection comes from - ''' + """ logger.debug("Received new connection from %s:%d", *address) stream = TCPStream(stream, address) @@ -139,7 +138,7 @@ def _check_for_errors(self): return def _get_address(self): - '''Get the host and port of the tcp server''' + """Get the host and port of the tcp server""" sock = self._sockets[list(self._sockets.keys())[0]] port = sock.getsockname()[1] diff --git a/daisy/tcp/tcp_stream.py b/daisy/tcp/tcp_stream.py index 375f2c48..d5d13679 100644 --- a/daisy/tcp/tcp_stream.py +++ b/daisy/tcp/tcp_stream.py @@ -9,7 +9,7 @@ class TCPStream(IOLooper): - '''Wrapper around :class:`tornado.iostream.IOStream` to send + """Wrapper around :class:`tornado.iostream.IOStream` to send :class:`TCPMessage` objects. Args: @@ -19,7 +19,7 @@ class TCPStream(IOLooper): address (tuple): The address the stream originates from. - ''' + """ def __init__(self, stream, address): super().__init__() @@ -27,7 +27,7 @@ def __init__(self, stream, address): self.address = address def send_message(self, message): - '''Send a message through this stream asynchronously. + """Send a message through this stream asynchronously. If the stream is closed, raises a :class:`StreamClosedError`. Successful return of this function does not guarantee that the message @@ -39,7 +39,7 @@ def send_message(self, message): message (:class:`daisy.TCPMessage`): Message to send over the stream. - ''' + """ if self.stream is None: raise StreamClosedError(*self.address) @@ -47,7 +47,7 @@ def send_message(self, message): self.ioloop.add_callback(self._send_message, message) def close(self): - '''Close this stream.''' + """Close this stream.""" try: self.stream.close() except Exception: @@ -56,7 +56,7 @@ def close(self): self.stream = None def closed(self): - '''True if this stream was closed.''' + """True if this stream was closed.""" if self.stream is None: return True return self.stream.closed() @@ -67,7 +67,7 @@ async def _send_message(self, message): logger.error("No TCPStream available, can't send message.") pickled_data = pickle.dumps(message) - message_size_bytes = struct.pack('I', len(pickled_data)) + message_size_bytes = struct.pack("I", len(pickled_data)) message_bytes = message_size_bytes + pickled_data try: @@ -94,8 +94,8 @@ async def _get_message(self): try: size = await self.stream.read_bytes(4) - size = struct.unpack('I', size)[0] - assert (size < 65535) # TODO: parameterize max message size + size = struct.unpack("I", size)[0] + assert size < 65535 # TODO: parameterize max message size pickled_data = await self.stream.read_bytes(size) except tornado.iostream.StreamClosedError: diff --git a/daisy/worker.py b/daisy/worker.py index 6c5cfa0f..67d6d12e 100644 --- a/daisy/worker.py +++ b/daisy/worker.py @@ -8,8 +8,8 @@ logger = logging.getLogger(__name__) -class Worker(): - '''Create and start a worker, running a user-specified function in its own +class Worker: + """Create and start a worker, running a user-specified function in its own process. Args: @@ -24,9 +24,9 @@ class Worker(): If given, the context will be passed on to the worker via environment variables. - ''' + """ - __next_id = multiprocessing.Value('L') + __next_id = multiprocessing.Value("L") @staticmethod def get_next_id(): @@ -43,25 +43,24 @@ def __init__(self, spawn_function, context=None, error_queue=None): self.context = Context() else: self.context = context.copy() - self.context['worker_id'] = self.worker_id + self.context["worker_id"] = self.worker_id self.error_queue = error_queue self.process = None self.start() def start(self): - '''Start this worker. Note that workers are automatically started when - created. Use this function to re-start a stopped worker.''' + """Start this worker. Note that workers are automatically started when + created. Use this function to re-start a stopped worker.""" if self.process is not None: return - self.process = multiprocessing.Process( - target=lambda: self.__spawn_wrapper()) + self.process = multiprocessing.Process(target=lambda: self.__spawn_wrapper()) self.process.start() def stop(self): - '''Stop this worker.''' + """Stop this worker.""" if self.process is None: return @@ -76,16 +75,16 @@ def stop(self): self.process = None def __spawn_wrapper(self): - '''Thin wrapper around the user-specified spawn function to set - environment variables, redirect output, and to capture exceptions.''' + """Thin wrapper around the user-specified spawn function to set + environment variables, redirect output, and to capture exceptions.""" try: os.environ[self.context.ENV_VARIABLE] = self.context.to_env() log_base = daisy_logging.get_worker_log_basename( - self.worker_id, - self.context.get('task_id', None)) + self.worker_id, self.context.get("task_id", None) + ) daisy_logging.redirect_stdouterr(log_base) self.spawn_function() @@ -98,8 +97,8 @@ def __spawn_wrapper(self): self.error_queue.put(e, timeout=1) except queue.Full: logger.error( - "%s failed to forward exception, error queue is full", - self) + "%s failed to forward exception, error queue is full", self + ) except KeyboardInterrupt: diff --git a/daisy/worker_pool.py b/daisy/worker_pool.py index 5919cf8e..dfc466de 100644 --- a/daisy/worker_pool.py +++ b/daisy/worker_pool.py @@ -9,7 +9,7 @@ class WorkerPool: - '''Manages a pool of workers in individual processes. All workers are + """Manages a pool of workers in individual processes. All workers are spawned by the same user-specified function. Args: @@ -27,7 +27,7 @@ class WorkerPool: A context to pass on to workers through environment variables. Will be augmented with ``worker_id``, a unique ID for each worker that is spawned by this pool. - ''' + """ def __init__(self, spawn_worker_function, context=None): @@ -42,7 +42,7 @@ def __init__(self, spawn_worker_function, context=None): self.error_queue = multiprocessing.Queue(100) def set_num_workers(self, num_workers): - '''Set the number of workers in this pool. + """Set the number of workers in this pool. If higher than the current number of running workers, new workers will be spawned using ``spawn_worker_function``. @@ -56,7 +56,7 @@ def set_num_workers(self, num_workers): num_workers (int): The new number of workers for this pool. - ''' + """ logger.debug("setting number of workers to %d", num_workers) @@ -75,8 +75,8 @@ def inc_num_workers(self, num_workers): self.__start_workers(num_workers) def stop(self, worker_id=None): - '''Stop all current workers in this pool (``worker_id == None``) or a - specific worker.''' + """Stop all current workers in this pool (``worker_id == None``) or a + specific worker.""" if worker_id is None: self.set_num_workers(0) @@ -91,11 +91,11 @@ def stop(self, worker_id=None): del self.workers[worker_id] def check_for_errors(self): - '''If a worker fails with an exception, this exception will be queued + """If a worker fails with an exception, this exception will be queued in this pool to be propagated to the process that created the pool. Call this function periodically to check the queue and raise exceptions coming from the workers in the calling process. - ''' + """ try: error = self.error_queue.get(block=False) @@ -111,10 +111,7 @@ def __start_workers(self, n): Worker(self.spawn_function, self.context, self.error_queue) for _ in range(n) ] - self.workers.update({ - worker.worker_id: worker - for worker in new_workers - }) + self.workers.update({worker.worker_id: worker for worker in new_workers}) def __stop_workers(self, n): diff --git a/docs/conf.py b/docs/conf.py index 59459d62..527a8b60 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,14 +19,14 @@ # -- Project information ----------------------------------------------------- -project = 'daisy' -copyright = '2019, Jan Funke, Tri Nguyen, Carolin Malin-Mayor, Arlo Sheridan, Philipp Hanslovsky, Chris Barnes' -author = 'Jan Funke, Tri Nguyen, Carolin Malin-Mayor, Arlo Sheridan, Philipp Hanslovsky, Chris Barnes' +project = "daisy" +copyright = "2019, Jan Funke, Tri Nguyen, Carolin Malin-Mayor, Arlo Sheridan, Philipp Hanslovsky, Chris Barnes" +author = "Jan Funke, Tri Nguyen, Carolin Malin-Mayor, Arlo Sheridan, Philipp Hanslovsky, Chris Barnes" # The short X.Y version -version = '' +version = "" # The full version, including alpha/beta/rc tags -release = 'v0.2' +release = "v0.2" # -- General configuration --------------------------------------------------- @@ -39,22 +39,22 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.mathjax', - 'sphinx.ext.viewcode', + "sphinx.ext.autodoc", + "sphinx.ext.mathjax", + "sphinx.ext.viewcode", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -66,7 +66,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. pygments_style = None @@ -77,7 +77,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -88,7 +88,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -104,7 +104,7 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'daisydoc' +htmlhelp_basename = "daisydoc" # -- Options for LaTeX output ------------------------------------------------ @@ -113,15 +113,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -131,8 +128,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'daisy.tex', 'daisy Documentation', - 'Jan Funke, Tri Nguyen, Carolin Malin-Mayor, Arlo Sheridan, Philipp Hanslovsky, Chris Barnes', 'manual'), + ( + master_doc, + "daisy.tex", + "daisy Documentation", + "Jan Funke, Tri Nguyen, Carolin Malin-Mayor, Arlo Sheridan, Philipp Hanslovsky, Chris Barnes", + "manual", + ), ] @@ -140,10 +142,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'daisy', 'daisy Documentation', - [author], 1) -] +man_pages = [(master_doc, "daisy", "daisy Documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -152,9 +151,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'daisy', 'daisy Documentation', - author, 'daisy', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "daisy", + "daisy Documentation", + author, + "daisy", + "One line description of project.", + "Miscellaneous", + ), ] @@ -173,7 +178,7 @@ # epub_uid = '' # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # -- Extension configuration ------------------------------------------------- diff --git a/examples/basic_workflow.py b/examples/basic_workflow.py index 0958c31e..e659a2ec 100644 --- a/examples/basic_workflow.py +++ b/examples/basic_workflow.py @@ -11,29 +11,45 @@ def process_function(b): return 0 -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--total_roi_size', '-t', nargs='+', - help="Size of total region to process", - default=[100, 100]) - parser.add_argument('--block_read_size', '-r', nargs='+', - help="Size of block read region", - default=[20, 20]) - parser.add_argument('--block_write_size', '-w', nargs='+', - help="Size of block write region", - default=[16, 16]) - parser.add_argument('--num_workers', '-nw', type=int, - help="Number of processes to spawn", - default=1) - parser.add_argument('--read_write_conflict', '-rwc', action='store_true', - help="Flag to not schedule overlapping blocks" - " at the same time. Default is false") + parser.add_argument( + "--total_roi_size", + "-t", + nargs="+", + help="Size of total region to process", + default=[100, 100], + ) + parser.add_argument( + "--block_read_size", + "-r", + nargs="+", + help="Size of block read region", + default=[20, 20], + ) + parser.add_argument( + "--block_write_size", + "-w", + nargs="+", + help="Size of block write region", + default=[16, 16], + ) + parser.add_argument( + "--num_workers", "-nw", type=int, help="Number of processes to spawn", default=1 + ) + parser.add_argument( + "--read_write_conflict", + "-rwc", + action="store_true", + help="Flag to not schedule overlapping blocks" + " at the same time. Default is false", + ) args = parser.parse_args() ndims = len(args.total_roi_size) # define total region of interest (roi) - total_roi_start = daisy.Coordinate((0,)*ndims) + total_roi_start = daisy.Coordinate((0,) * ndims) total_roi_size = daisy.Coordinate(args.total_roi_size) total_roi = daisy.Roi(total_roi_start, total_roi_size) @@ -46,9 +62,10 @@ def process_function(b): # call run_blockwise daisy.run_blockwise( - total_roi, - block_read_roi, - block_write_roi, - process_function=process_function, - read_write_conflict=args.read_write_conflict, - num_workers=args.num_workers) + total_roi, + block_read_roi, + block_write_roi, + process_function=process_function, + read_write_conflict=args.read_write_conflict, + num_workers=args.num_workers, + ) diff --git a/examples/batch_task.py b/examples/batch_task.py index 9c80d14d..555e8aae 100644 --- a/examples/batch_task.py +++ b/examples/batch_task.py @@ -11,15 +11,16 @@ import daisy daisy_version = float(pkg_resources.get_distribution("daisy").version) -assert daisy_version >= 1, \ - f"This script was written for daisy v1.0 but current installed version " \ +assert daisy_version >= 1, ( + f"This script was written for daisy v1.0 but current installed version " f"is {daisy_version}" +) logging.basicConfig(level=logging.INFO) logger = logging.getLogger("BatchTask") -class Database(): +class Database: def __init__(self, db_host, db_id, overwrite=False): self.table_name = "completion_db_col" @@ -28,7 +29,7 @@ def __init__(self, db_host, db_id, overwrite=False): # Use SQLite self.use_sql = True os.makedirs("daisy_db", exist_ok=True) - self.con = sqlite3.connect(f'daisy_db/{db_id}.db', check_same_thread=False) + self.con = sqlite3.connect(f"daisy_db/{db_id}.db", check_same_thread=False) self.cur = self.con.cursor() if overwrite: @@ -54,21 +55,22 @@ def __init__(self, db_host, db_id, overwrite=False): if self.table_name not in db.list_collection_names(): self.completion_db = db[self.table_name] self.completion_db.create_index( - [('block_id', pymongo.ASCENDING)], - name='block_id') + [("block_id", pymongo.ASCENDING)], name="block_id" + ) else: self.completion_db = db[self.table_name] def check(self, block_id): if self.use_sql: - block_id = '_'.join([str(s) for s in block_id]) + block_id = "_".join([str(s) for s in block_id]) res = self.cur.execute( - f"SELECT * FROM {self.table_name} where block_id = '{block_id}'").fetchall() + f"SELECT * FROM {self.table_name} where block_id = '{block_id}'" + ).fetchall() if len(res): return True else: - if self.completion_db.count_documents({'block_id': block_id}) >= 1: + if self.completion_db.count_documents({"block_id": block_id}) >= 1: return True return False @@ -76,53 +78,68 @@ def check(self, block_id): def add_finished(self, block_id): if self.use_sql: - block_id = '_'.join([str(s) for s in block_id]) + block_id = "_".join([str(s) for s in block_id]) self.cur.execute(f"INSERT INTO {self.table_name} VALUES ('{block_id}')") self.con.commit() else: - document = { - 'block_id': block_id - } + document = {"block_id": block_id} self.completion_db.insert_one(document) -class BatchTask(): - '''Example base class for a batchable Daisy task. +class BatchTask: + """Example base class for a batchable Daisy task. This class takes care of some plumbing work such as creating a database to keep track of finished block and writing a config and batch file that users can submit jobs to a job system like SLURM. Derived tasks will only need to implement the code for computing the task. - ''' + """ @staticmethod def parse_args(ap): try: ap.add_argument( - "--db_host", type=str, - help='MongoDB database host name. If `None` (default), use SQLite', - default=None) - # default='10.117.28.139') + "--db_host", + type=str, + help="MongoDB database host name. If `None` (default), use SQLite", + default=None, + ) + # default='10.117.28.139') ap.add_argument( - "--db_name", type=str, help='MongoDB database project name', - default=None) + "--db_name", + type=str, + help="MongoDB database project name", + default=None, + ) ap.add_argument( - "--overwrite", type=int, help='Whether to overwrite completed blocks', - default=0) + "--overwrite", + type=int, + help="Whether to overwrite completed blocks", + default=0, + ) ap.add_argument( - "--num_workers", type=int, help='Number of workers to run', - default=4) + "--num_workers", type=int, help="Number of workers to run", default=4 + ) ap.add_argument( - "--no_launch_workers", type=int, help='Whether to run workers automatically', - default=0) + "--no_launch_workers", + type=int, + help="Whether to run workers automatically", + default=0, + ) ap.add_argument( - "--config_hash", type=str, help='config string, used to keep track of progress', - default=None) + "--config_hash", + type=str, + help="config string, used to keep track of progress", + default=None, + ) ap.add_argument( - "--task_name", type=str, help='Name of task, default to class name', - default=None) + "--task_name", + type=str, + help="Name of task, default to class name", + default=None, + ) except argparse.ArgumentError as e: print("Current task has conflicting arguments with BatchTask!") raise e @@ -133,13 +150,13 @@ def __init__(self, config=None, config_file=None, task_id=None): if config_file: print(f"Loading from config_file: {config_file}") - with open(config_file, 'r') as f: + with open(config_file, "r") as f: config = json.load(f) assert config is not None for key in config: - setattr(self, '%s' % key, config[key]) + setattr(self, "%s" % key, config[key]) if self.task_name is None: self.task_name = str(self.__class__.__name__) @@ -147,13 +164,18 @@ def __init__(self, config=None, config_file=None, task_id=None): self.__init_config = copy.deepcopy(config) if self.config_hash is None: - config_str = ''.join(['%s' % (v,) for k,v in config.items() - if k not in ['overwrite', 'num_workers', 'no_launch_workers']]) + config_str = "".join( + [ + "%s" % (v,) + for k, v in config.items() + if k not in ["overwrite", "num_workers", "no_launch_workers"] + ] + ) self.config_hash = str(hashlib.md5(config_str.encode()).hexdigest()) config_hash_short = self.config_hash[0:8] - self.db_id = '%s_%s' % (self.task_name, self.config_hash) - db_id_short = '%s_%s' % (self.task_name, config_hash_short) + self.db_id = "%s_%s" % (self.task_name, self.config_hash) + db_id_short = "%s_%s" % (self.task_name, config_hash_short) # if not given, we need to give the task a unique id for chaining if task_id is None: @@ -168,67 +190,69 @@ def _task_init(self, config): assert False, "Function needs to be implemented by subclass" def prepare_task(self): - '''Called by user to get a `daisy.Task`. It should call + """Called by user to get a `daisy.Task`. It should call `_write_config()` and return with a call to `_prepare_task()` Returns: `daisy.Task` object - ''' + """ assert False, "Function needs to be implemented by subclass" def _write_config(self, worker_filename, extra_config=None): - '''Make a config file for workers. Workers can then be run on the + """Make a config file for workers. Workers can then be run on the command line on potentially a different machine and use this file to initialize its variables. Args: extra_config (``dict``, optional): Any extra configs that should be written for workers - ''' + """ config = self.__init_config if extra_config: for k in extra_config: config[k] = extra_config[k] - self.config_file = os.path.join( - '.run_configs', '%s.config' % self.db_id) + self.config_file = os.path.join(".run_configs", "%s.config" % self.db_id) - self.new_actor_cmd = 'python %s run_worker %s' % ( - worker_filename, self.config_file) + self.new_actor_cmd = "python %s run_worker %s" % ( + worker_filename, + self.config_file, + ) if self.db_name is None: - self.db_name = '%s' % self.db_id + self.db_name = "%s" % self.db_id - config['db_id'] = self.db_id + config["db_id"] = self.db_id - os.makedirs('.run_configs', exist_ok=True) - with open(self.config_file, 'w') as f: + os.makedirs(".run_configs", exist_ok=True) + with open(self.config_file, "w") as f: json.dump(config, f) # write batch script - self.sbatch_file = os.path.join('.run_configs', '%s.sh' % self.db_id) + self.sbatch_file = os.path.join(".run_configs", "%s.sh" % self.db_id) self.generate_batch_script( self.sbatch_file, self.new_actor_cmd, - log_dir='.logs', + log_dir=".logs", logname=self.db_id, - ) + ) self.write_config_called = True def generate_batch_script( - self, - output_script, - run_cmd, - log_dir, - logname, - cpu_time=11, - queue='short', - cpu_cores=1, - cpu_mem=2, - gpu=None): - '''Example SLURM script.''' + self, + output_script, + run_cmd, + log_dir, + logname, + cpu_time=11, + queue="short", + cpu_cores=1, + cpu_mem=2, + gpu=None, + ): + """Example SLURM script.""" text = [] text.append("#!/bin/bash") @@ -236,7 +260,7 @@ def generate_batch_script( if gpu is not None: text.append("#SBATCH -p gpu") - if gpu == '' or gpu == 'any': + if gpu == "" or gpu == "any": text.append("#SBATCH --gres=gpu:1") else: text.append("#SBATCH --gres=gpu:{}:1".format(gpu)) @@ -250,25 +274,26 @@ def generate_batch_script( text.append("") text.append(run_cmd) - with open(output_script, 'w') as f: - f.write('\n'.join(text)) + with open(output_script, "w") as f: + f.write("\n".join(text)) def _prepare_task( - self, - total_roi, - read_roi, - write_roi, - check_fn=None, - fit='shrink', - read_write_conflict=False, - upstream_tasks=None, - ): - - assert self.write_config_called, ( - "`BatchTask._write_config()` was not called") - - print("Processing total_roi %s with read_roi %s and write_roi %s" % ( - total_roi, read_roi, write_roi)) + self, + total_roi, + read_roi, + write_roi, + check_fn=None, + fit="shrink", + read_write_conflict=False, + upstream_tasks=None, + ): + + assert self.write_config_called, "`BatchTask._write_config()` was not called" + + print( + "Processing total_roi %s with read_roi %s and write_roi %s" + % (total_roi, read_roi, write_roi) + ) if check_fn is None: check_fn = lambda b: self._default_check_fn(b) @@ -305,7 +330,7 @@ def _prepare_task( ) def _default_check_fn(self, block): - '''The default check function uses database for checking completion''' + """The default check function uses database for checking completion""" if self.overwrite: return False @@ -313,7 +338,7 @@ def _default_check_fn(self, block): return self.database.check(block.block_id) def _worker_impl(self, args): - '''Worker function implementation''' + """Worker function implementation""" assert False, "Function needs to be implemented by subclass" def _new_worker(self): @@ -322,11 +347,12 @@ def _new_worker(self): self.run_worker() def run_worker(self): - '''Wrapper for `_worker_impl()`''' + """Wrapper for `_worker_impl()`""" - assert 'DAISY_CONTEXT' in os.environ, ( - "DAISY_CONTEXT must be defined as an environment variable") - logger.info("WORKER: Running with context %s" % os.environ['DAISY_CONTEXT']) + assert ( + "DAISY_CONTEXT" in os.environ + ), "DAISY_CONTEXT must be defined as an environment variable" + logger.info("WORKER: Running with context %s" % os.environ["DAISY_CONTEXT"]) database = Database(self.db_host, self.db_id) client_scheduler = daisy.Client() @@ -335,14 +361,20 @@ def run_worker(self): with client_scheduler.acquire_block() as block: if block is None: break - logger.info(f'Received block {block}') + logger.info(f"Received block {block}") self._worker_impl(block) database.add_finished(block.block_id) def init_callback_fn(self, context): - print("sbatch command: DAISY_CONTEXT={} sbatch --parsable {}\n".format( - context.to_env(), self.sbatch_file)) + print( + "sbatch command: DAISY_CONTEXT={} sbatch --parsable {}\n".format( + context.to_env(), self.sbatch_file + ) + ) - print("Terminal command: DAISY_CONTEXT={} {}\n".format( - context.to_env(), self.new_actor_cmd)) + print( + "Terminal command: DAISY_CONTEXT={} {}\n".format( + context.to_env(), self.new_actor_cmd + ) + ) diff --git a/examples/chaining_example.py b/examples/chaining_example.py index 1c8abaa6..35f22d3a 100644 --- a/examples/chaining_example.py +++ b/examples/chaining_example.py @@ -5,37 +5,49 @@ import daisy from gaussian_smoothing2 import GaussianSmoothingTask -if __name__ == '__main__': +if __name__ == "__main__": ap = argparse.ArgumentParser() - ap.add_argument("in_file", type=str, help='The input container') - ap.add_argument("in_ds_name", type=str, help='The name of the dataset') - ap.add_argument("--out_file", type=str, default=None, - help='The output container, defaults to be the same as in_file') - ap.add_argument('--sigma', '-s', type=float, - help="Sigma to use for gaussian filter", - default=2) - ap.add_argument('--block_read_size', '-r', nargs='+', - help="Size of block read region", - default=[20, 200, 200]) - ap.add_argument('--block_write_size', '-w', nargs='+', - help="Size of block write region", - default=[18, 180, 180]) + ap.add_argument("in_file", type=str, help="The input container") + ap.add_argument("in_ds_name", type=str, help="The name of the dataset") + ap.add_argument( + "--out_file", + type=str, + default=None, + help="The output container, defaults to be the same as in_file", + ) + ap.add_argument( + "--sigma", "-s", type=float, help="Sigma to use for gaussian filter", default=2 + ) + ap.add_argument( + "--block_read_size", + "-r", + nargs="+", + help="Size of block read region", + default=[20, 200, 200], + ) + ap.add_argument( + "--block_write_size", + "-w", + nargs="+", + help="Size of block write region", + default=[18, 180, 180], + ) config = GaussianSmoothingTask.parse_args(ap) config1 = copy.deepcopy(config) - config1['out_ds_name'] = 'volumes/raw_smoothed' - daisy_task1 = GaussianSmoothingTask( - config1, task_id='Gaussian1').prepare_task() + config1["out_ds_name"] = "volumes/raw_smoothed" + daisy_task1 = GaussianSmoothingTask(config1, task_id="Gaussian1").prepare_task() # here we reuse parameters but set the output dataset of the previous # task as input config2 = copy.deepcopy(config) - config2['in_ds_name'] = 'volumes/raw_smoothed' - config2['out_ds_name'] = 'volumes/raw_smoothed_smoothed' - daisy_task2 = GaussianSmoothingTask( - config2, task_id='Gaussian2').prepare_task(upstream_tasks=[daisy_task1]) + config2["in_ds_name"] = "volumes/raw_smoothed" + config2["out_ds_name"] = "volumes/raw_smoothed_smoothed" + daisy_task2 = GaussianSmoothingTask(config2, task_id="Gaussian2").prepare_task( + upstream_tasks=[daisy_task1] + ) done = daisy.run_blockwise([daisy_task1, daisy_task2]) diff --git a/examples/gaussian_smoothing1.py b/examples/gaussian_smoothing1.py index 21dea4f1..5ee8e1e6 100644 --- a/examples/gaussian_smoothing1.py +++ b/examples/gaussian_smoothing1.py @@ -9,8 +9,9 @@ from batch_task import BatchTask import logging + logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('GaussianSmoothingTask') +logger = logging.getLogger("GaussianSmoothingTask") def smooth(block, dataset, output, sigma=5): @@ -22,39 +23,52 @@ def smooth(block, dataset, output, sigma=5): logger.debug("Got data of shape %s" % str(data.shape)) # apply gaussian filter - r = scipy.ndimage.gaussian_filter( - data, sigma=sigma, mode='constant') + r = scipy.ndimage.gaussian_filter(data, sigma=sigma, mode="constant") # write result to output dataset in block.write_roi - to_write = daisy.Array( - data=r, - roi=block.read_roi, - voxel_size=dataset.voxel_size) + to_write = daisy.Array(data=r, roi=block.read_roi, voxel_size=dataset.voxel_size) output[block.write_roi] = to_write[block.write_roi] logger.debug("Done") return 0 -if __name__ == '__main__': +if __name__ == "__main__": ap = argparse.ArgumentParser() - ap.add_argument("in_file", type=str, help='The input container') - ap.add_argument("in_ds_name", type=str, help='The name of the dataset') - ap.add_argument("--out_file", type=str, default=None, - help='The output container, defaults to be the same as in_file') - ap.add_argument("--out_ds_name", type=str, default=None, - help='The name of the dataset, defaults to be in_ds_name + smoothed') - ap.add_argument('--sigma', '-s', type=float, - help="Sigma to use for gaussian filter", - default=2) - ap.add_argument('--block_read_size', '-r', nargs='+', - help="Size of block read region", - default=[20, 200, 200]) - ap.add_argument('--block_write_size', '-w', nargs='+', - help="Size of block write region", - default=[18, 180, 180]) - ap.add_argument("--num_workers", type=int, help='Number of workers to run', - default=4) + ap.add_argument("in_file", type=str, help="The input container") + ap.add_argument("in_ds_name", type=str, help="The name of the dataset") + ap.add_argument( + "--out_file", + type=str, + default=None, + help="The output container, defaults to be the same as in_file", + ) + ap.add_argument( + "--out_ds_name", + type=str, + default=None, + help="The name of the dataset, defaults to be in_ds_name + smoothed", + ) + ap.add_argument( + "--sigma", "-s", type=float, help="Sigma to use for gaussian filter", default=2 + ) + ap.add_argument( + "--block_read_size", + "-r", + nargs="+", + help="Size of block read region", + default=[20, 200, 200], + ) + ap.add_argument( + "--block_write_size", + "-w", + nargs="+", + help="Size of block write region", + default=[18, 180, 180], + ) + ap.add_argument( + "--num_workers", type=int, help="Number of workers to run", default=4 + ) config = ap.parse_args() @@ -66,16 +80,18 @@ def smooth(block, dataset, output, sigma=5): ndims = len(total_roi.get_offset()) # define block read and write rois - assert len(config.block_read_size) == ndims,\ - "Read size must have same dimensions as in_file" - assert len(config.block_write_size) == ndims,\ - "Write size must have same dimensions as in_file" + assert ( + len(config.block_read_size) == ndims + ), "Read size must have same dimensions as in_file" + assert ( + len(config.block_write_size) == ndims + ), "Write size must have same dimensions as in_file" block_read_size = daisy.Coordinate(config.block_read_size) block_write_size = daisy.Coordinate(config.block_write_size) block_read_size *= dataset.voxel_size block_write_size *= dataset.voxel_size context = (block_read_size - block_write_size) / 2 - block_read_roi = daisy.Roi((0,)*ndims, block_read_size) + block_read_roi = daisy.Roi((0,) * ndims, block_read_size) block_write_roi = daisy.Roi(context, block_write_size) # prepare output dataset @@ -83,30 +99,32 @@ def smooth(block, dataset, output, sigma=5): if config.out_file is None: config.out_file = config.in_file if config.out_ds_name is None: - config.out_ds_name = config.in_ds_name + '_smoothed' + config.out_ds_name = config.in_ds_name + "_smoothed" - logger.info(f'Processing data to {config.out_file}/{config.out_ds_name}') + logger.info(f"Processing data to {config.out_file}/{config.out_ds_name}") output_dataset = daisy.prepare_ds( - config.out_file, - config.out_ds_name, - total_roi=output_roi, - voxel_size=dataset.voxel_size, - dtype=dataset.dtype, - write_size=block_write_roi.get_shape()) + config.out_file, + config.out_ds_name, + total_roi=output_roi, + voxel_size=dataset.voxel_size, + dtype=dataset.dtype, + write_size=block_write_roi.get_shape(), + ) # make task task = daisy.Task( - 'GaussianSmoothingTask', - total_roi, - block_read_roi, - block_write_roi, - process_function=lambda b: smooth( - b, dataset, output_dataset, sigma=config.sigma), - read_write_conflict=False, - num_workers=config.num_workers, - fit='shrink' - ) + "GaussianSmoothingTask", + total_roi, + block_read_roi, + block_write_roi, + process_function=lambda b: smooth( + b, dataset, output_dataset, sigma=config.sigma + ), + read_write_conflict=False, + num_workers=config.num_workers, + fit="shrink", + ) # run task ret = daisy.run_blockwise([task]) @@ -114,4 +132,4 @@ def smooth(block, dataset, output, sigma=5): if ret: logger.info("Ran all blocks successfully!") else: - logger.info("Did not run all blocks successfully...") \ No newline at end of file + logger.info("Did not run all blocks successfully...") diff --git a/examples/gaussian_smoothing2.py b/examples/gaussian_smoothing2.py index 5f42583c..5a5ba893 100644 --- a/examples/gaussian_smoothing2.py +++ b/examples/gaussian_smoothing2.py @@ -9,8 +9,9 @@ from batch_task import BatchTask import logging + logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('GaussianSmoothingTask') +logger = logging.getLogger("GaussianSmoothingTask") def smooth(block, dataset, output, sigma=5): @@ -22,14 +23,10 @@ def smooth(block, dataset, output, sigma=5): logger.debug("Got data of shape %s" % str(data.shape)) # apply gaussian filter - r = scipy.ndimage.gaussian_filter( - data, sigma=sigma, mode='constant') + r = scipy.ndimage.gaussian_filter(data, sigma=sigma, mode="constant") # write result to output dataset in block.write_roi - to_write = daisy.Array( - data=r, - roi=block.read_roi, - voxel_size=dataset.voxel_size) + to_write = daisy.Array(data=r, roi=block.read_roi, voxel_size=dataset.voxel_size) output[block.write_roi] = to_write[block.write_roi] logger.debug("Done") return 0 @@ -47,16 +44,18 @@ def _task_init(self): ndims = len(total_roi.get_offset()) # define block read and write rois - assert len(self.block_read_size) == ndims,\ - "Read size must have same dimensions as in_file" - assert len(self.block_write_size) == ndims,\ - "Write size must have same dimensions as in_file" + assert ( + len(self.block_read_size) == ndims + ), "Read size must have same dimensions as in_file" + assert ( + len(self.block_write_size) == ndims + ), "Write size must have same dimensions as in_file" block_read_size = daisy.Coordinate(self.block_read_size) block_write_size = daisy.Coordinate(self.block_write_size) block_read_size *= dataset.voxel_size block_write_size *= dataset.voxel_size context = (block_read_size - block_write_size) / 2 - block_read_roi = daisy.Roi((0,)*ndims, block_read_size) + block_read_roi = daisy.Roi((0,) * ndims, block_read_size) block_write_roi = daisy.Roi(context, block_write_size) # prepare output dataset @@ -64,17 +63,18 @@ def _task_init(self): if self.out_file is None: self.out_file = self.in_file if self.out_ds_name is None: - self.out_ds_name = self.in_ds_name + '_smoothed' + self.out_ds_name = self.in_ds_name + "_smoothed" - logger.info(f'Processing data to {self.out_file}/{self.out_ds_name}') + logger.info(f"Processing data to {self.out_file}/{self.out_ds_name}") output_dataset = daisy.prepare_ds( - self.out_file, - self.out_ds_name, - total_roi=output_roi, - voxel_size=dataset.voxel_size, - dtype=dataset.dtype, - write_size=block_write_roi.get_shape()) + self.out_file, + self.out_ds_name, + total_roi=output_roi, + voxel_size=dataset.voxel_size, + dtype=dataset.dtype, + write_size=block_write_roi.get_shape(), + ) # save variables for other functions self.total_roi = total_roi @@ -88,42 +88,62 @@ def prepare_task(self, upstream_tasks=None): worker_filename = os.path.realpath(__file__) self._write_config(worker_filename) return self._prepare_task( - self.total_roi, - self.block_read_roi, - self.block_write_roi, - read_write_conflict=False, - upstream_tasks=upstream_tasks, - ) + self.total_roi, + self.block_read_roi, + self.block_write_roi, + read_write_conflict=False, + upstream_tasks=upstream_tasks, + ) def _worker_impl(self, block): - '''Worker function implementation''' + """Worker function implementation""" smooth(block, self.dataset, self.output_dataset, sigma=self.sigma) -if __name__ == '__main__': +if __name__ == "__main__": - if len(sys.argv) > 1 and sys.argv[1] == 'run_worker': + if len(sys.argv) > 1 and sys.argv[1] == "run_worker": task = GaussianSmoothingTask(config_file=sys.argv[2]) task.run_worker() else: ap = argparse.ArgumentParser() - ap.add_argument("in_file", type=str, help='The input container') - ap.add_argument("in_ds_name", type=str, help='The name of the dataset') - ap.add_argument("--out_file", type=str, default=None, - help='The output container, defaults to be the same as in_file') - ap.add_argument("--out_ds_name", type=str, default=None, - help='The name of the dataset, defaults to be in_ds_name + smoothed') - ap.add_argument('--sigma', '-s', type=float, - help="Sigma to use for gaussian filter", - default=2) - ap.add_argument('--block_read_size', '-r', nargs='+', - help="Size of block read region", - default=[20, 200, 200]) - ap.add_argument('--block_write_size', '-w', nargs='+', - help="Size of block write region", - default=[18, 180, 180]) + ap.add_argument("in_file", type=str, help="The input container") + ap.add_argument("in_ds_name", type=str, help="The name of the dataset") + ap.add_argument( + "--out_file", + type=str, + default=None, + help="The output container, defaults to be the same as in_file", + ) + ap.add_argument( + "--out_ds_name", + type=str, + default=None, + help="The name of the dataset, defaults to be in_ds_name + smoothed", + ) + ap.add_argument( + "--sigma", + "-s", + type=float, + help="Sigma to use for gaussian filter", + default=2, + ) + ap.add_argument( + "--block_read_size", + "-r", + nargs="+", + help="Size of block read region", + default=[20, 200, 200], + ) + ap.add_argument( + "--block_write_size", + "-w", + nargs="+", + help="Size of block write region", + default=[18, 180, 180], + ) config = GaussianSmoothingTask.parse_args(ap) task = GaussianSmoothingTask(config) @@ -132,4 +152,4 @@ def _worker_impl(self, block): if done: logger.info("Ran all blocks successfully!") else: - logger.info("Did not run all blocks successfully...") \ No newline at end of file + logger.info("Did not run all blocks successfully...") diff --git a/examples/hdf_to_zarr.py b/examples/hdf_to_zarr.py index 36c12e77..83ce0d89 100644 --- a/examples/hdf_to_zarr.py +++ b/examples/hdf_to_zarr.py @@ -51,35 +51,30 @@ def _task_init(self): elif self.in_ds.n_channel_dims == 1: num_channels = self.in_ds.shape[0] else: - raise RuntimeError( - "more than one channel not yet implemented, sorry...") + raise RuntimeError("more than one channel not yet implemented, sorry...") self.ds_roi = self.in_ds.roi sub_roi = None if self.roi_offset is not None or self.roi_shape is not None: assert self.roi_offset is not None and self.roi_shape is not None - self.schedule_roi = daisy.Roi( - tuple(self.roi_offset), tuple(self.roi_shape)) + self.schedule_roi = daisy.Roi(tuple(self.roi_offset), tuple(self.roi_shape)) sub_roi = self.schedule_roi else: self.schedule_roi = self.in_ds.roi if self.chunk_shape_voxel is None: self.chunk_shape_voxel = calculateNearIsotropicDimensions( - voxel_size, self.max_voxel_count) + voxel_size, self.max_voxel_count + ) logger.info(voxel_size) logger.info(self.chunk_shape_voxel) self.chunk_shape_voxel = Coordinate(self.chunk_shape_voxel) - self.schedule_roi = self.schedule_roi.snap_to_grid( - voxel_size, - mode='grow') - out_ds_roi = self.ds_roi.snap_to_grid( - voxel_size, - mode='grow') + self.schedule_roi = self.schedule_roi.snap_to_grid(voxel_size, mode="grow") + out_ds_roi = self.ds_roi.snap_to_grid(voxel_size, mode="grow") - self.write_size = self.chunk_shape_voxel*voxel_size + self.write_size = self.chunk_shape_voxel * voxel_size scheduling_block_size = self.write_size self.write_roi = daisy.Roi((0, 0, 0), scheduling_block_size) @@ -88,9 +83,9 @@ def _task_init(self): # with sub_roi, the coordinates are absolute # so we'd need to align total_roi to the write size too self.schedule_roi = self.schedule_roi.snap_to_grid( - self.write_size, mode='grow') - out_ds_roi = out_ds_roi.snap_to_grid( - self.write_size, mode='grow') + self.write_size, mode="grow" + ) + out_ds_roi = out_ds_roi.snap_to_grid(self.write_size, mode="grow") logger.info(f"out_ds_roi: {out_ds_roi}") logger.info(f"schedule_roi: {self.schedule_roi}") @@ -98,7 +93,7 @@ def _task_init(self): logger.info(f"voxel_size: {voxel_size}") if self.out_file is None: - self.out_file = '.'.join(self.in_file.split('.')[0:-1])+'.zarr' + self.out_file = ".".join(self.in_file.split(".")[0:-1]) + ".zarr" if self.out_ds_name is None: self.out_ds_name = self.in_ds_name @@ -113,9 +108,9 @@ def _task_init(self): dtype=self.in_ds.dtype, num_channels=num_channels, force_exact_write_size=True, - compressor={'id': 'blosc', 'clevel': 3}, + compressor={"id": "blosc", "clevel": 3}, delete=delete, - ) + ) def prepare_task(self): @@ -123,7 +118,16 @@ def prepare_task(self): logger.info( "Rechunking %s/%s to %s/%s with chunk_shape_voxel %s (write_size %s, scheduling %s)" - % (self.in_file, self.in_ds_name, self.out_file, self.out_ds_name, self.chunk_shape_voxel, self.write_size, self.write_roi)) + % ( + self.in_file, + self.in_ds_name, + self.out_file, + self.out_ds_name, + self.chunk_shape_voxel, + self.write_size, + self.write_roi, + ) + ) logger.info("ROI: %s" % self.schedule_roi) worker_filename = os.path.realpath(__file__) @@ -134,10 +138,10 @@ def prepare_task(self): read_roi=self.write_roi, write_roi=self.write_roi, check_fn=lambda b: self.check_fn(b), - ) + ) def _worker_impl(self, block): - '''Worker function implementation''' + """Worker function implementation""" self.out_ds[block.write_roi] = self.in_ds[block.write_roi] def check_fn(self, block): @@ -151,38 +155,42 @@ def check_fn(self, block): if __name__ == "__main__": - if len(sys.argv) > 1 and sys.argv[1] == 'run_worker': + if len(sys.argv) > 1 and sys.argv[1] == "run_worker": task = HDF2ZarrTask(config_file=sys.argv[2]) task.run_worker() else: - ap = argparse.ArgumentParser( - description="Create a zarr/N5 container from hdf.") - ap.add_argument("in_file", type=str, help='The input container') - ap.add_argument("in_ds_name", type=str, help='The name of the dataset') + ap = argparse.ArgumentParser(description="Create a zarr/N5 container from hdf.") + ap.add_argument("in_file", type=str, help="The input container") + ap.add_argument("in_ds_name", type=str, help="The name of the dataset") ap.add_argument( - "--out_file", type=str, default=None, - help='The output container, defaults to be the same as in_file+.zarr' - ) + "--out_file", + type=str, + default=None, + help="The output container, defaults to be the same as in_file+.zarr", + ) ap.add_argument( - "--out_ds_name", type=str, default=None, - help='The name of the dataset, defaults to be in_ds_name' - ) - ap.add_argument( - "--chunk_shape_voxel", type=int, help='The size of a chunk in voxels', - nargs='+', default=None - ) - ap.add_argument( - "--max_voxel_count", type=int, default=256*1024, - help='If chunk_shape_voxel is not given, use this value to calculate' - 'a near isotropic chunk shape', - ) + "--out_ds_name", + type=str, + default=None, + help="The name of the dataset, defaults to be in_ds_name", + ) ap.add_argument( - "--roi_offset", type=int, help='', - nargs='+', default=None) + "--chunk_shape_voxel", + type=int, + help="The size of a chunk in voxels", + nargs="+", + default=None, + ) ap.add_argument( - "--roi_shape", type=int, help='', - nargs='+', default=None) + "--max_voxel_count", + type=int, + default=256 * 1024, + help="If chunk_shape_voxel is not given, use this value to calculate" + "a near isotropic chunk shape", + ) + ap.add_argument("--roi_offset", type=int, help="", nargs="+", default=None) + ap.add_argument("--roi_shape", type=int, help="", nargs="+", default=None) config = HDF2ZarrTask.parse_args(ap) task = HDF2ZarrTask(config) diff --git a/examples/minimal_example/server.py b/examples/minimal_example/server.py index 1ca9014c..a7e6064f 100644 --- a/examples/minimal_example/server.py +++ b/examples/minimal_example/server.py @@ -1,7 +1,8 @@ import multiprocessing + # workaround for MacOS: # this needs to be set before importing any library that uses multiprocessing -multiprocessing.set_start_method('fork') +multiprocessing.set_start_method("fork") import logging import daisy import subprocess @@ -10,12 +11,7 @@ logging.basicConfig(level=logging.INFO) -def run_scheduler( - total_roi, - read_roi, - write_roi, - num_workers -): +def run_scheduler(total_roi, read_roi, write_roi, num_workers): dummy_task = daisy.Task( "dummy_task", @@ -51,9 +47,4 @@ def start_worker(): write_roi = daisy.Roi((0, 0, 0), (256, 256, 256)) read_roi = write_roi.grow((16, 16, 16), (16, 16, 16)) - run_scheduler( - total_roi, - read_roi, - write_roi, - num_workers=10 - ) + run_scheduler(total_roi, read_roi, write_roi, num_workers=10) diff --git a/examples/visualize.py b/examples/visualize.py index f782ab57..274f304b 100644 --- a/examples/visualize.py +++ b/examples/visualize.py @@ -12,73 +12,77 @@ parser = argparse.ArgumentParser() parser.add_argument( - '--file', - '-f', - type=str, - action='append', - help="The path to the container to show") + "--file", "-f", type=str, action="append", help="The path to the container to show" +) parser.add_argument( - '--datasets', - '-d', + "--datasets", + "-d", type=str, - nargs='+', - action='append', - help="The datasets in the container to show") + nargs="+", + action="append", + help="The datasets in the container to show", +) parser.add_argument( - '--graphs', - '-g', + "--graphs", + "-g", type=str, - nargs='+', - action='append', - help="The graphs in the container to show") + nargs="+", + action="append", + help="The graphs in the container to show", +) parser.add_argument( - '--no-browser', - '-n', + "--no-browser", + "-n", type=bool, - nargs='?', + nargs="?", default=False, const=True, - help="If set, do not open a browser, just print a URL") + help="If set, do not open a browser, just print a URL", +) args = parser.parse_args() -neuroglancer.set_server_bind_address('0.0.0.0') +neuroglancer.set_server_bind_address("0.0.0.0") viewer = neuroglancer.Viewer() + def to_slice(slice_str): - values = [int(x) for x in slice_str.split(':')] + values = [int(x) for x in slice_str.split(":")] if len(values) == 1: return values[0] return slice(*values) + def parse_ds_name(ds): - tokens = ds.split('[') + tokens = ds.split("[") if len(tokens) == 1: return ds, None ds, slices = tokens - slices = list(map(to_slice, slices.rstrip(']').split(','))) + slices = list(map(to_slice, slices.rstrip("]").split(","))) return ds, slices + class Project: def __init__(self, array, dim, value): self.array = array self.dim = dim self.value = value - self.shape = array.shape[:self.dim] + array.shape[self.dim + 1:] + self.shape = array.shape[: self.dim] + array.shape[self.dim + 1 :] self.dtype = array.dtype def __getitem__(self, key): - slices = key[:self.dim] + (self.value,) + key[self.dim:] + slices = key[: self.dim] + (self.value,) + key[self.dim :] ret = self.array[slices] return ret + def slice_dataset(a, slices): dims = a.roi.dims @@ -88,19 +92,21 @@ def slice_dataset(a, slices): if isinstance(s, slice): raise NotImplementedError("Slicing not yet implemented!") else: - index = (s - a.roi.get_begin()[d])//a.voxel_size[d] + index = (s - a.roi.get_begin()[d]) // a.voxel_size[d] a.data = Project(a.data, d, index) a.roi = daisy.Roi( - a.roi.get_begin()[:d] + a.roi.get_begin()[d + 1:], - a.roi.get_shape()[:d] + a.roi.get_shape()[d + 1:]) - a.voxel_size = a.voxel_size[:d] + a.voxel_size[d + 1:] + a.roi.get_begin()[:d] + a.roi.get_begin()[d + 1 :], + a.roi.get_shape()[:d] + a.roi.get_shape()[d + 1 :], + ) + a.voxel_size = a.voxel_size[:d] + a.voxel_size[d + 1 :] return a + def open_dataset(f, ds): original_ds = ds ds, slices = parse_ds_name(ds) - slices_str = original_ds[len(ds):] + slices_str = original_ds[len(ds) :] try: dataset_as = [] @@ -129,7 +135,9 @@ def open_dataset(f, ds): if a.roi.dims == 2: print("ROI is 2D, recruiting next channel to z dimension") - a.roi = daisy.Roi((0,) + a.roi.get_begin(), (a.shape[-3],) + a.roi.get_shape()) + a.roi = daisy.Roi( + (0,) + a.roi.get_begin(), (a.shape[-3],) + a.roi.get_shape() + ) a.voxel_size = daisy.Coordinate((1,) + a.voxel_size) if a.roi.dims == 4: @@ -143,7 +151,10 @@ def open_dataset(f, ds): return [(a, ds)] else: - return [([daisy.open_ds(f, f"{ds}/{key}") for key in zarr.open(f)[ds].keys()], ds)] + return [ + ([daisy.open_ds(f, f"{ds}/{key}") for key in zarr.open(f)[ds].keys()], ds) + ] + for f, datasets in zip(args.file, args.datasets): @@ -159,18 +170,12 @@ def open_dataset(f, ds): print(type(e), e) print("Didn't work, checking if this is multi-res...") - scales = glob.glob(os.path.join(f, ds, 's*')) + scales = glob.glob(os.path.join(f, ds, "s*")) if len(scales) == 0: print(f"Couldn't read {ds}, skipping...") raise e - print("Found scales %s" % ([ - os.path.relpath(s, f) - for s in scales - ],)) - a = [ - open_dataset(f, os.path.relpath(scale_ds, f)) - for scale_ds in scales - ] + print("Found scales %s" % ([os.path.relpath(s, f) for s in scales],)) + a = [open_dataset(f, os.path.relpath(scale_ds, f)) for scale_ds in scales] for a in dataset_as: arrays.append(a) @@ -185,8 +190,8 @@ def open_dataset(f, ds): graph_annotations = [] try: - ids = daisy.open_ds(f, graph + '-ids').data - loc = daisy.open_ds(f, graph + '-locations').data + ids = daisy.open_ds(f, graph + "-ids").data + loc = daisy.open_ds(f, graph + "-locations").data except: loc = daisy.open_ds(f, graph).data ids = None @@ -199,15 +204,15 @@ def open_dataset(f, ds): l = np.concatenate([[0], l]) graph_annotations.append( neuroglancer.EllipsoidAnnotation( - center=l[::-1], - radii=(5, 5, 5), - id=i)) + center=l[::-1], radii=(5, 5, 5), id=i + ) + ) graph_layer = neuroglancer.AnnotationLayer( - annotations=graph_annotations, - voxel_size=(1, 1, 1)) + annotations=graph_annotations, voxel_size=(1, 1, 1) + ) with viewer.txn() as s: - s.layers.append(name='graph', layer=graph_layer) + s.layers.append(name="graph", layer=graph_layer) url = str(viewer) print(url) @@ -215,4 +220,4 @@ def open_dataset(f, ds): webbrowser.open_new(url) print("Press ENTER to quit") -input() \ No newline at end of file +input() diff --git a/tests/test_client.py b/tests/test_client.py index d763a4cd..61feb739 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,9 +1,7 @@ import daisy import unittest import multiprocessing as mp -from daisy.messages import ( - AcquireBlock, ReleaseBlock, SendBlock, - ExceptionMessage) +from daisy.messages import AcquireBlock, ReleaseBlock, SendBlock, ExceptionMessage from daisy.tcp import TCPServer @@ -40,18 +38,14 @@ def run_test_server(self, block, conn): def test_basic(self): roi = daisy.Roi((0, 0, 0), (10, 10, 10)) task_id = 1 - block = daisy.Block( - roi, roi, roi, - block_id=1, - task_id=task_id) + block = daisy.Block(roi, roi, roi, block_id=1, task_id=task_id) parent_conn, child_conn = mp.Pipe() - server_process = mp.Process(target=self.run_test_server, - args=(block, child_conn)) + server_process = mp.Process( + target=self.run_test_server, args=(block, child_conn) + ) server_process.start() host, port = parent_conn.recv() - context = daisy.Context( - hostname=host, port=port, - task_id=task_id, worker_id=1) + context = daisy.Context(hostname=host, port=port, task_id=task_id, worker_id=1) client = daisy.Client(context=context) with client.acquire_block() as block: block.status = daisy.BlockStatus.SUCCESS diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 339f028d..c6b4f555 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -53,6 +53,7 @@ def task_1d(tmpdir): check_function=None, ) + @pytest.fixture def task_no_conflicts(tmpdir): # block ids: @@ -155,6 +156,7 @@ def overlapping_tasks(): ) return task_1, task_2 + def test_simple_no_conflicts(task_no_conflicts): scheduler = Scheduler([task_no_conflicts]) @@ -166,7 +168,7 @@ def test_simple_no_conflicts(task_no_conflicts): assert block.read_roi == expected_block.read_roi assert block.write_roi == expected_block.write_roi assert block.block_id == expected_block.block_id - + block = scheduler.acquire_block(task_no_conflicts.task_id) expected_block = Block( @@ -176,6 +178,7 @@ def test_simple_no_conflicts(task_no_conflicts): assert block.write_roi == expected_block.write_roi assert block.block_id == expected_block.block_id + def test_simple_acquire_block(task_1d): scheduler = Scheduler([task_1d]) block = scheduler.acquire_block(task_1d.task_id) diff --git a/tests/test_server.py b/tests/test_server.py index 6cde20b4..c7ef72ef 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -4,22 +4,24 @@ logging.basicConfig(level=logging.DEBUG) + class TestServer(unittest.TestCase): def test_basic(self): task = daisy.Task( - 'test_server_task', - total_roi=daisy.Roi((0,), (100,)), - read_roi=daisy.Roi((0,), (10,)), - write_roi=daisy.Roi((1,), (8,)), - process_function=lambda b: self.process_block(b), - check_function=None, - read_write_conflict=True, - fit='valid', - num_workers=1, - max_retries=2, - timeout=None) + "test_server_task", + total_roi=daisy.Roi((0,), (100,)), + read_roi=daisy.Roi((0,), (10,)), + write_roi=daisy.Roi((1,), (8,)), + process_function=lambda b: self.process_block(b), + check_function=None, + read_write_conflict=True, + fit="valid", + num_workers=1, + max_retries=2, + timeout=None, + ) server = daisy.Server() server.run_blockwise([task]) diff --git a/tests/tmpdir_test.py b/tests/tmpdir_test.py index a7cf0bbd..9893689c 100644 --- a/tests/tmpdir_test.py +++ b/tests/tmpdir_test.py @@ -34,7 +34,8 @@ class TmpDirTestCase(unittest.TestCase): tearDownClass should explicitly call the ``super`` method in the method definition. """ - _output_root = '' + + _output_root = "" _cleanup = True def path_to(self, *args): @@ -48,7 +49,8 @@ def path_to_cls(cls, *args): def setUpClass(cls): timestamp = datetime.now().isoformat() cls._output_root = mkdtemp( - prefix='daisy_{}_{}_'.format(cls.__name__, timestamp)) + prefix="daisy_{}_{}_".format(cls.__name__, timestamp) + ) def setUp(self): os.mkdir(self.path_to()) @@ -59,9 +61,9 @@ def tearDown(self): if self._cleanup: shutil.rmtree(path) else: - warn('Directory {} was not deleted'.format(path)) + warn("Directory {} was not deleted".format(path)) except OSError as e: - if '[Errno 2]' in str(e): + if "[Errno 2]" in str(e): pass else: raise @@ -72,13 +74,14 @@ def tearDownClass(cls): if cls._cleanup: os.rmdir(cls.path_to_cls()) else: - warn('Directory {} was not deleted'.format(cls.path_to_cls())) + warn("Directory {} was not deleted".format(cls.path_to_cls())) except OSError as e: - if '[Errno 39]' in str(e): + if "[Errno 39]" in str(e): warn( - 'Directory {} could not be deleted as it still had data ' - 'in it'.format(cls.path_to_cls())) - elif '[Errno 2]' in str(e): + "Directory {} could not be deleted as it still had data " + "in it".format(cls.path_to_cls()) + ) + elif "[Errno 2]" in str(e): pass else: raise