diff --git a/daisy/__init__.py b/daisy/__init__.py index 4aa01b07..d650b344 100644 --- a/daisy/__init__.py +++ b/daisy/__init__.py @@ -1,16 +1,16 @@ from __future__ import absolute_import -from .block import Block, BlockStatus # noqa -from .blocks import expand_roi_to_grid # noqa -from .blocks import expand_write_roi_to_grid # noqa -from .client import Client # noqa -from .context import Context # noqa -from .convenience import run_blockwise # noqa -from .coordinate import Coordinate # noqa -from .dependency_graph import DependencyGraph, BlockwiseDependencyGraph # noqa -from .logging import get_worker_log_basename, redirect_stdouterr # noqa -from .roi import Roi # noqa -from .scheduler import Scheduler # noqa -from .server import Server # noqa -from .task import Task # noqa -from .worker import Worker # noqa -from .worker_pool import WorkerPool # noqa +from .block import Block, BlockStatus # noqa +from .blocks import expand_roi_to_grid # noqa +from .blocks import expand_write_roi_to_grid # noqa +from .client import Client # noqa +from .context import Context # noqa +from .convenience import run_blockwise # noqa +from .coordinate import Coordinate # noqa +from .dependency_graph import DependencyGraph, BlockwiseDependencyGraph # noqa +from .logging import get_worker_log_basename, redirect_stdouterr # noqa +from .roi import Roi # noqa +from .scheduler import Scheduler # noqa +from .server import Server # noqa +from .task import Task # noqa +from .worker import Worker # noqa +from .worker_pool import WorkerPool # noqa diff --git a/daisy/block.py b/daisy/block.py index 48a3e47a..9cc92214 100644 --- a/daisy/block.py +++ b/daisy/block.py @@ -66,14 +66,8 @@ class Block(Freezable): The id of the Task that this block belongs to. Defaults to None. """ - def __init__( - self, - total_roi, - read_roi, - write_roi, - block_id=None, - task_id=None): + def __init__(self, total_roi, read_roi, write_roi, block_id=None, task_id=None): self.read_roi = read_roi self.write_roi = write_roi self.requested_write_roi = write_roi # save original write_roi @@ -88,13 +82,10 @@ def __init__( self.freeze() def copy(self): - return copy.deepcopy(self) def __compute_block_id(self, total_roi, write_roi, shift=None): - block_index = ( - write_roi.offset - total_roi.offset - ) / write_roi.shape + block_index = (write_roi.offset - total_roi.offset) / write_roi.shape # block_id will be the cantor number for this block index block_id = int(cantor_number(block_index)) @@ -102,7 +93,6 @@ def __compute_block_id(self, total_roi, write_roi, shift=None): return block_id def __repr__(self): - return "%s/%d with read ROI %s and write ROI %s" % ( self.block_id[0], self.block_id[1], diff --git a/daisy/block_bookkeeper.py b/daisy/block_bookkeeper.py index 173e8565..4f64196e 100644 --- a/daisy/block_bookkeeper.py +++ b/daisy/block_bookkeeper.py @@ -2,7 +2,6 @@ class BlockLog: - def __init__(self, block, stream): self.block = block self.stream = stream @@ -10,43 +9,44 @@ def __init__(self, block, stream): class BlockBookkeeper: - def __init__(self, processing_timeout=None): self.processing_timeout = processing_timeout self.sent_blocks = {} def notify_block_sent(self, block, stream): - '''Notify the bookkeeper that a block has been sent to a client (i.e., - a stream to the client).''' + """Notify the bookkeeper that a block has been sent to a client (i.e., + a stream to the client).""" - assert block.block_id not in self.sent_blocks, \ - f"Attempted to send block {block}, although it is already being " \ + assert block.block_id not in 
self.sent_blocks, ( + f"Attempted to send block {block}, although it is already being " f"processed by {self.sent_blocks[block.block_id].stream}" + ) self.sent_blocks[block.block_id] = BlockLog(block, stream) def notify_block_returned(self, block, stream): - '''Notify the bookkeeper that a block was returned.''' + """Notify the bookkeeper that a block was returned.""" - assert block.block_id in self.sent_blocks, \ - f"Block {block} was returned by {stream}, but is not in list " \ + assert block.block_id in self.sent_blocks, ( + f"Block {block} was returned by {stream}, but is not in list " "of sent blocks" + ) log = self.sent_blocks[block.block_id] block.started_processing = log.time_sent block.stopped_processing = time.time() - assert stream == log.stream, \ - f"Block {block} was returned by {stream}, but was sent to" \ - f"{log.stream}" + assert stream == log.stream, ( + f"Block {block} was returned by {stream}, but was sent to" f"{log.stream}" + ) del self.sent_blocks[block.block_id] def is_valid_return(self, block, stream): - '''Check whether the block from the given client (i.e., stream) is + """Check whether the block from the given client (i.e., stream) is expected to be returned from this client. This is to avoid double returning blocks that have already been returned as lost blocks, but - still come back from the client due to race conditions.''' + still come back from the client due to race conditions.""" # block was never sent or already returned if block.block_id not in self.sent_blocks: @@ -59,14 +59,13 @@ def is_valid_return(self, block, stream): return True def get_lost_blocks(self): - '''Return a list of blocks that were sent and are lost, either because + """Return a list of blocks that were sent and are lost, either because the stream to the client closed or the processing timed out. Those blocks are removed from the sent-list with the call of this - function.''' + function.""" lost_block_ids = [] for block_id, log in self.sent_blocks.items(): - # is the stream to the client still alive? if log.stream.closed(): lost_block_ids.append(block_id) @@ -78,7 +77,6 @@ def get_lost_blocks(self): lost_blocks = [] for block_id in lost_block_ids: - lost_block = self.sent_blocks[block_id].block lost_blocks.append(lost_block) del self.sent_blocks[block_id] diff --git a/daisy/blocks.py b/daisy/blocks.py index aedb8eef..3e1c266d 100644 --- a/daisy/blocks.py +++ b/daisy/blocks.py @@ -106,7 +106,6 @@ def compute_level_conflicts(self): prev_level_offset = None for level, level_offset in enumerate(self.level_offsets): - # get conflicts to previous level if prev_level_offset is not None and self.read_write_conflict: conflict_offsets = self.get_conflict_offsets( @@ -121,13 +120,11 @@ def compute_level_conflicts(self): return level_conflict_offsets def create_dependency_graph(self): - blocks = [] for level_offset, level_conflicts in zip( self.level_offsets, self.level_conflicts ): - # all block offsets of the current level (relative to total ROI # start) block_dim_offsets = [ @@ -138,16 +135,11 @@ def create_dependency_graph(self): ] # TODO: can we do this part lazily? 
This might be a lot of # Coordinates - block_offsets = [ - Coordinate(o) - for o in product(*block_dim_offsets)] + block_offsets = [Coordinate(o) for o in product(*block_dim_offsets)] # convert to global coordinates block_offsets = [ - o + ( - self.total_roi.get_begin() - - self.block_read_roi.get_begin() - ) + o + (self.total_roi.get_begin() - self.block_read_roi.get_begin()) for o in block_offsets ] @@ -173,12 +165,8 @@ def compute_level_stride(self): self.block_write_roi ), "Read ROI must contain write ROI." - context_ul = ( - self.block_write_roi.get_begin() - - self.block_read_roi.get_begin()) - context_lr = ( - self.block_read_roi.get_end() - - self.block_write_roi.get_end()) + context_ul = self.block_write_roi.get_begin() - self.block_read_roi.get_begin() + context_lr = self.block_read_roi.get_end() - self.block_write_roi.get_end() max_context = Coordinate( (max(ul, lr) for ul, lr in zip(context_ul, context_lr)) @@ -195,14 +183,14 @@ def compute_level_stride(self): # to avoid overlapping write ROIs, increase the stride to the next # multiple of write shape write_shape = self.block_write_roi.get_shape() - level_stride = Coordinate(( - ((level - 1) // w + 1) * w - for level, w in zip(min_level_stride, write_shape) - )) + level_stride = Coordinate( + ( + ((level - 1) // w + 1) * w + for level, w in zip(min_level_stride, write_shape) + ) + ) - logger.debug( - "final level stride (multiples of write size) is %s", - level_stride) + logger.debug("final level stride (multiples of write size) is %s", level_stride) return level_stride @@ -219,26 +207,16 @@ def compute_level_offsets(self): ) dim_offsets = [ - range(0, e, step) - for e, step in zip(self.level_stride, write_stride) + range(0, e, step) for e, step in zip(self.level_stride, write_stride) ] - level_offsets = list( - reversed([ - Coordinate(o) - for o in product(*dim_offsets) - ]) - ) + level_offsets = list(reversed([Coordinate(o) for o in product(*dim_offsets)])) logger.debug("level offsets: %s", level_offsets) return level_offsets - def get_conflict_offsets( - self, - level_offset, - prev_level_offset, - level_stride): + def get_conflict_offsets(self, level_offset, prev_level_offset, level_stride): """Get the offsets to all previous level blocks that are in conflict with the current level blocks.""" @@ -250,22 +228,15 @@ def get_conflict_offsets( for op, ls in zip(offset_to_prev, level_stride) ] - conflict_offsets = [ - Coordinate(o) - for o in product(*conflict_dim_offsets) - ] - logger.debug( - "conflict offsets to previous level: %s", - conflict_offsets) + conflict_offsets = [Coordinate(o) for o in product(*conflict_dim_offsets)] + logger.debug("conflict offsets to previous level: %s", conflict_offsets) return conflict_offsets def enumerate_dependencies(self, conflict_offsets, block_offsets): - inclusion_criteria = { "valid": lambda b: self.total_roi.contains(b.read_roi), - "overhang": lambda b: self.total_roi.contains( - b.write_roi.get_begin()), + "overhang": lambda b: self.total_roi.contains(b.write_roi.get_begin()), "shrink": lambda b: self.shrink_possible(b), }[self.fit] @@ -278,7 +249,6 @@ def enumerate_dependencies(self, conflict_offsets, block_offsets): blocks = [] for block_offset in block_offsets: - # create a block shifted by the current offset block = Block( self.total_roi, @@ -294,7 +264,6 @@ def enumerate_dependencies(self, conflict_offsets, block_offsets): # get all blocks in conflict with the current block conflicts = [] for conflict_offset in conflict_offsets: - conflict = Block( self.total_roi, block.read_roi 
+ conflict_offset, @@ -314,7 +283,6 @@ def enumerate_dependencies(self, conflict_offsets, block_offsets): return blocks def shrink_possible(self, block): - if not self.total_roi.contains(block.write_roi.get_begin()): return False @@ -388,10 +356,7 @@ def get_subgraph_blocks(self, sub_roi): def expand_roi_to_grid(sub_roi, total_roi, read_roi, write_roi): """Expands given roi so that its write region is aligned to write_roi""" - offset = ( - write_roi.get_begin() + - total_roi.get_begin() - - read_roi.get_begin()) + offset = write_roi.get_begin() + total_roi.get_begin() - read_roi.get_begin() begin = sub_roi.get_begin() - offset end = sub_roi.get_end() - offset @@ -405,10 +370,7 @@ def expand_roi_to_grid(sub_roi, total_roi, read_roi, write_roi): def expand_request_roi_to_grid(req_roi, total_roi, read_roi, write_roi): """Expands given roi so that its write region is aligned to write_roi""" - offset = ( - write_roi.get_begin() + - total_roi.get_begin() - - read_roi.get_begin()) + offset = write_roi.get_begin() + total_roi.get_begin() - read_roi.get_begin() begin = req_roi.get_begin() - offset end = req_roi.get_end() - offset @@ -430,7 +392,5 @@ def expand_write_roi_to_grid(roi, write_roi): -(-roi.get_end() // write_roi.get_shape()), ) # `ceildiv` - roi = Roi( - roi[0] * write_roi.get_shape(), - (roi[1] - roi[0]) * write_roi.get_shape()) + roi = Roi(roi[0] * write_roi.get_shape(), (roi[1] - roi[0]) * write_roi.get_shape()) return roi diff --git a/daisy/cl_monitor.py b/daisy/cl_monitor.py index dddffc72..55a0d188 100644 --- a/daisy/cl_monitor.py +++ b/daisy/cl_monitor.py @@ -6,12 +6,12 @@ class TqdmLoggingHandler: - '''A logging handler that uses ``tqdm.tqdm.write`` in ``emit()``, such that + """A logging handler that uses ``tqdm.tqdm.write`` in ``emit()``, such that logging doesn't interfere with tqdm's progress bar. Heavily inspired by the fantastic https://github.com/EpicWink/tqdm-logging-wrapper/ - ''' + """ def __init__(self, handler): self.handler = handler @@ -27,27 +27,25 @@ def emit(self, record): class TaskSummary: - def __init__(self, state): - self.block_failures = [] self.state = state class BlockFailure: - def __init__(self, block, exception, worker_id): self.block = block self.exception = exception self.worker_id = worker_id def __repr__(self): - return f"block {self.block.block_id[1]} in worker " \ + return ( + f"block {self.block.block_id[1]} in worker " f"{self.worker_id} with exception {repr(self.exception)}" + ) class CLMonitor(ServerObserver): - def __init__(self, server): super().__init__(server) self.progresses = {} @@ -56,13 +54,13 @@ def __init__(self, server): self._wrap_logging_handlers() def _wrap_logging_handlers(self): - '''This adds a TqdmLoggingHandler around each logging handler that has + """This adds a TqdmLoggingHandler around each logging handler that has a TTY stream attached to it, so that logging doesn't interfere with the progress bar. 
Heavily inspired by the fantastic https://github.com/EpicWink/tqdm-logging-wrapper/ - ''' + """ logger = logging.root for i in range(len(logger.handlers)): @@ -70,7 +68,6 @@ def _wrap_logging_handlers(self): logger.handlers[i] = TqdmLoggingHandler(logger.handlers[i]) def _is_tty_stream_handler(self, handler): - return ( hasattr(handler, "stream") and hasattr(handler.stream, "isatty") @@ -84,17 +81,15 @@ def on_release_block(self, task_id, task_state): self._update_state(task_id, task_state) def on_block_failure(self, block, exception, context): - task_id = block.block_id[0] self.summaries[task_id].block_failures.append( - BlockFailure(block, exception, context['worker_id'])) + BlockFailure(block, exception, context["worker_id"]) + ) def on_task_start(self, task_id, task_state): - self.summaries[task_id] = TaskSummary(task_state) def on_task_done(self, task_id, task_state): - if task_id not in self.summaries: self.summaries[task_id] = TaskSummary(task_state) else: @@ -111,7 +106,6 @@ def on_task_done(self, task_id, task_state): self.progresses[task_id].close() def on_server_exit(self): - for task_id, progress in self.progresses.items(): progress.close() @@ -122,7 +116,6 @@ def on_server_exit(self): max_entries = 100 for task_id, summary in self.summaries.items(): - num_block_failures = len(summary.block_failures) print() @@ -130,8 +123,10 @@ def on_server_exit(self): print() state = summary.state print(f" num blocks : {state.total_block_count}") - print(f" completed ✔: {state.completed_count} " - f"(skipped {state.skipped_count})") + print( + f" completed ✔: {state.completed_count} " + f"(skipped {state.skipped_count})" + ) print(f" failed ✗: {state.failed_count}") print(f" orphaned ∅: {state.orphaned_count}") print() @@ -151,8 +146,8 @@ def on_server_exit(self): print() for block_failure in summary.block_failures[:10]: log_basename = daisy_logging.get_worker_log_basename( - block_failure.worker_id, - block_failure.block.block_id[0]) + block_failure.worker_id, block_failure.block.block_id[0] + ) print(f" {log_basename}.err / .out") if num_block_failures > 10: print(" ...") @@ -161,22 +156,21 @@ def on_server_exit(self): print(" all blocks processed successfully") def _update_state(self, task_id, task_state): - if task_id not in self.progresses: total = task_state.total_block_count self.progresses[task_id] = tqdm_auto( - total=total, - desc=task_id + " ▶", - unit='blocks', - leave=True) - - self.progresses[task_id].set_postfix({ - '⧗': task_state.pending_count, - '▶': task_state.processing_count, - '✔': task_state.completed_count, - '✗': task_state.failed_count, - '∅': task_state.orphaned_count - }) + total=total, desc=task_id + " ▶", unit="blocks", leave=True + ) + + self.progresses[task_id].set_postfix( + { + "⧗": task_state.pending_count, + "▶": task_state.processing_count, + "✔": task_state.completed_count, + "✗": task_state.failed_count, + "∅": task_state.orphaned_count, + } + ) completed = task_state.completed_count delta = completed - self.progresses[task_id].n diff --git a/daisy/context.py b/daisy/context.py index 35f20c02..1f6c4744 100644 --- a/daisy/context.py +++ b/daisy/context.py @@ -5,64 +5,51 @@ logger = logging.getLogger(__name__) -class Context(): - - ENV_VARIABLE = 'DAISY_CONTEXT' +class Context: + ENV_VARIABLE = "DAISY_CONTEXT" def __init__(self, **kwargs): - self.__dict = dict(**kwargs) def copy(self): - return copy.deepcopy(self) def to_env(self): - - return ':'.join('%s=%s' % (k, v) for k, v in self.__dict.items()) + return ":".join("%s=%s" % (k, v) for k, v in 
self.__dict.items()) def __setitem__(self, k, v): - k = str(k) v = str(v) - if '=' in k or ':' in k: + if "=" in k or ":" in k: raise RuntimeError("Context variables must not contain = or :.") - if '=' in v or ':' in v: + if "=" in v or ":" in v: raise RuntimeError("Context values must not contain = or :.") self.__dict[k] = v def __getitem__(self, k): - return self.__dict[k] def get(self, k, v=None): - return self.__dict.get(k, v) def __repr__(self): - return self.to_env() @staticmethod def from_env(): - try: - - tokens = os.environ[Context.ENV_VARIABLE].split(':') + tokens = os.environ[Context.ENV_VARIABLE].split(":") except KeyError: - - logger.error( - "%s environment variable not found!", - Context.ENV_VARIABLE) + logger.error("%s environment variable not found!", Context.ENV_VARIABLE) raise context = Context() for token in tokens: - k, v = token.split('=') + k, v = token.split("=") context[k] = v return context diff --git a/daisy/coordinate.py b/daisy/coordinate.py index a05770bf..c37998ea 100644 --- a/daisy/coordinate.py +++ b/daisy/coordinate.py @@ -1 +1 @@ -from funlib.geometry import Coordinate # noqa \ No newline at end of file +from funlib.geometry import Coordinate # noqa diff --git a/daisy/dependency_graph.py b/daisy/dependency_graph.py index 185fb551..86b4e169 100644 --- a/daisy/dependency_graph.py +++ b/daisy/dependency_graph.py @@ -160,9 +160,7 @@ def inclusion_criteria(self): # TODO: Can't we remove this entirely by pre computing the write_roi inclusion_criteria = { "valid": lambda b: self.total_write_roi.contains(b.write_roi), - "overhang": lambda b: self.total_write_roi.contains( - b.write_roi.begin - ), + "overhang": lambda b: self.total_write_roi.contains(b.write_roi.begin), "shrink": lambda b: self.shrink_possible(b), }[self.fit] return inclusion_criteria @@ -202,12 +200,12 @@ def _num_level_blocks(self, level): level_offset, self._level_stride, num_blocks, - axis_blocks) + axis_blocks, + ) return num_blocks def level_blocks(self, level): - for block_offset in self._compute_level_block_offsets(level): block = Block( self.total_read_roi, @@ -230,9 +228,7 @@ def root_gen(self): def _block_offset(self, block): # The block offset is the offset of the read roi relative to total roi - block_offset = ( - block.read_roi.offset - - self.total_read_roi.offset) + block_offset = block.read_roi.offset - self.total_read_roi.offset return block_offset def _level(self, block): @@ -306,7 +302,6 @@ def upstream(self, block): return conflicts def enumerate_all_dependencies(self): - self._level_block_offsets = self.compute_level_block_offsets() for level in range(self.num_levels): @@ -332,12 +327,8 @@ def compute_level_stride(self) -> Coordinate: self.block_write_roi ), "Read ROI must contain write ROI." 
- context_ul = ( - self.block_write_roi.begin - - self.block_read_roi.begin) - context_lr = ( - self.block_read_roi.end - - self.block_write_roi.end) + context_ul = self.block_write_roi.begin - self.block_read_roi.begin + context_lr = self.block_read_roi.end - self.block_write_roi.end max_context = Coordinate( (max(ul, lr) for ul, lr in zip(context_ul, context_lr)) @@ -354,10 +345,12 @@ def compute_level_stride(self) -> Coordinate: # to avoid overlapping write ROIs, increase the stride to the next # multiple of write shape write_shape = self.block_write_roi.shape - level_stride = Coordinate(( - ((level - 1) // w + 1) * w - for level, w in zip(min_level_stride, write_shape) - )) + level_stride = Coordinate( + ( + ((level - 1) // w + 1) * w + for level, w in zip(min_level_stride, write_shape) + ) + ) # Handle case where min_level_stride > total_write_roi. # This case leads to levels with no blocks in them. This makes @@ -373,11 +366,10 @@ def compute_level_stride(self) -> Coordinate: ) % self.block_write_roi.shape level_stride = Coordinate( - (min(a, b) for a, b in zip(level_stride, write_roi_shape))) + (min(a, b) for a, b in zip(level_stride, write_roi_shape)) + ) - logger.debug( - "final level stride (multiples of write size) is %s", - level_stride) + logger.debug("final level stride (multiples of write size) is %s", level_stride) return level_stride @@ -395,15 +387,10 @@ def compute_level_offsets(self) -> List[Coordinate]: ) dim_offsets = [ - range(0, e, step) - for e, step in zip(self._level_stride, write_stride) + range(0, e, step) for e, step in zip(self._level_stride, write_stride) ] - level_offsets = list( - reversed([ - Coordinate(o) - for o in product(*dim_offsets)]) - ) + level_offsets = list(reversed([Coordinate(o) for o in product(*dim_offsets)])) logger.debug("level offsets: %s", level_offsets) @@ -418,7 +405,6 @@ def compute_level_conflicts(self) -> List[List[Coordinate]]: prev_level_offset = None for level, level_offset in enumerate(self._level_offsets): - # get conflicts to previous level if prev_level_offset is not None and self.read_write_conflict: conflict_offsets = self.get_conflict_offsets( @@ -451,10 +437,7 @@ def _compute_level_block_offsets(self, level): block_offset = Coordinate(offset) # convert to global coordinates - block_offset += ( - self.total_read_roi.begin - - self.block_read_roi.begin - ) + block_offset += self.total_read_roi.begin - self.block_read_roi.begin yield block_offset def compute_level_block_offsets(self) -> List[List[Coordinate]]: @@ -465,18 +448,11 @@ def compute_level_block_offsets(self) -> List[List[Coordinate]]: level_block_offsets = [] for level in range(self.num_levels): - - level_block_offsets.append( - list( - self._compute_level_block_offsets(level))) + level_block_offsets.append(list(self._compute_level_block_offsets(level))) return level_block_offsets - def get_conflict_offsets( - self, - level_offset, - prev_level_offset, - level_stride): + def get_conflict_offsets(self, level_offset, prev_level_offset, level_stride): """Get the offsets to all previous level blocks that are in conflict with the current level blocks.""" @@ -495,18 +471,12 @@ def get_offsets(op, ls): get_offsets(op, ls) for op, ls in zip(offset_to_prev, level_stride) ] - conflict_offsets = [ - Coordinate(o) - for o in product(*conflict_dim_offsets) - ] - logger.debug( - "conflict offsets to previous level: %s", - conflict_offsets) + conflict_offsets = [Coordinate(o) for o in product(*conflict_dim_offsets)] + logger.debug("conflict offsets to previous level: %s", 
conflict_offsets) return conflict_offsets def shrink_possible(self, block): - return self.total_write_roi.contains(block.write_roi.begin) def shrink(self, block): @@ -536,7 +506,8 @@ def get_subgraph_blocks(self, sub_roi, read_roi=False): # only need to check if a blocks read_roi overlaps with sub_roi. # This is the same behavior as when we want write_roi overlap sub_roi = sub_roi.grow( - self.read_write_context[0], self.read_write_context[1]) + self.read_write_context[0], self.read_write_context[1] + ) # TODO: handle unsatisfiable sub_rois # i.e. sub_roi is outside of *total_write_roi @@ -690,8 +661,7 @@ def __enumerate_all_dependencies(self): "Block dependency %s is not found for task %s." % (upstream_block.block_id, task_id) ) - self._downstream[upstream_block.block_id].add( block.block_id) + self._downstream[upstream_block.block_id].add(block.block_id) self._upstream[block.block_id].add(upstream_block.block_id) # enumerate all of the upstream / downstream dependencies @@ -713,7 +683,5 @@ def __enumerate_all_dependencies(self): "Block dependency %s is not found for task %s." % (upstream_block.block_id, task_id) ) - self._downstream[upstream_block.block_id].add( block.block_id) - self._upstream[block.block_id].add( upstream_block.block_id) + self._downstream[upstream_block.block_id].add(block.block_id) + self._upstream[block.block_id].add(upstream_block.block_id) diff --git a/daisy/freezable.py b/daisy/freezable.py index 93572868..fe772384 100644 --- a/daisy/freezable.py +++ b/daisy/freezable.py @@ -1,11 +1,9 @@ class Freezable(object): - __isfrozen = False def __setattr__(self, key, value): if self.__isfrozen and not hasattr(self, key): - raise TypeError( - "%r is frozen, you can't add attributes to it" % self) + raise TypeError("%r is frozen, you can't add attributes to it" % self) object.__setattr__(self, key, value) def freeze(self): diff --git a/daisy/logging.py b/daisy/logging.py index 8a917b58..0b1d13ea 100644 --- a/daisy/logging.py +++ b/daisy/logging.py @@ -4,16 +4,16 @@ # default log dir -LOG_BASEDIR = Path('./daisy_logs') +LOG_BASEDIR = Path("./daisy_logs") def set_log_basedir(path): - '''Set the base directory for logging (indivudal worker logs and detailed + """Set the base directory for logging (individual worker logs and detailed task summaries). If set to ``None``, all logging will be shown on the command line (which can get very messy). Default is ``./daisy_logs``. 
- ''' + """ global LOG_BASEDIR @@ -24,7 +24,7 @@ def set_log_basedir(path): def get_worker_log_basename(worker_id, task_id=None): - '''Get the basename of log files for individual workers.''' + """Get the basename of log files for individual workers.""" if LOG_BASEDIR is None: return None @@ -32,17 +32,17 @@ def get_worker_log_basename(worker_id, task_id=None): basename = LOG_BASEDIR if task_id is not None: basename /= task_id - basename /= f'worker_{worker_id}' + basename /= f"worker_{worker_id}" return basename -def redirect_stdouterr(basename, mode='w'): - '''Redirect stdout and stderr of the current process to files:: +def redirect_stdouterr(basename, mode="w"): + """Redirect stdout and stderr of the current process to files:: - .out - .err - ''' + .out + .err + """ if basename is None: return @@ -53,18 +53,11 @@ def redirect_stdouterr(basename, mode='w'): logdir = basename.parent logdir.mkdir(parents=True, exist_ok=True) - sys.stdout = _file_reopen( - basename.with_suffix('.out'), - mode, - sys.__stdout__) - sys.stderr = _file_reopen( - basename.with_suffix('.err'), - mode, - sys.__stderr__) + sys.stdout = _file_reopen(basename.with_suffix(".out"), mode, sys.__stdout__) + sys.stderr = _file_reopen(basename.with_suffix(".err"), mode, sys.__stderr__) def _file_reopen(filename, mode, file_obj): - new = open(filename, mode) newfd = new.fileno() targetfd = file_obj.fileno() diff --git a/daisy/messages/block_failed.py b/daisy/messages/block_failed.py index 30cfef00..08386e9c 100644 --- a/daisy/messages/block_failed.py +++ b/daisy/messages/block_failed.py @@ -2,8 +2,6 @@ class BlockFailed(ClientException): - def __init__(self, exception, block, context): - super().__init__(exception, context) self.block = block diff --git a/daisy/messages/message.py b/daisy/messages/message.py index 454bf057..cbc362c8 100644 --- a/daisy/messages/message.py +++ b/daisy/messages/message.py @@ -9,13 +9,13 @@ class Message(TCPMessage): class ExceptionMessage(Message): - '''A message representing an exception. + """A message representing an exception. Args: exception (:class:`Exception`): The exception to wrap into this message. 
- ''' + """ def __init__(self, exception): self.exception = exception diff --git a/daisy/ready_surface.py b/daisy/ready_surface.py index 928376e0..6fe39289 100644 --- a/daisy/ready_surface.py +++ b/daisy/ready_surface.py @@ -53,9 +53,7 @@ def mark_success(self, node): # check if any downstream nodes need to be added to the boundary for down_node in self.downstream(node): if not self.__add_to_boundary(down_node): - if all( - up_node in self.surface - for up_node in self.upstream(down_node)): + if all(up_node in self.surface for up_node in self.upstream(down_node)): new_ready_nodes.append(down_node) # check if any of the upstream nodes can be removed from surface diff --git a/daisy/roi.py b/daisy/roi.py index 85054d70..c4b04dbc 100644 --- a/daisy/roi.py +++ b/daisy/roi.py @@ -1 +1 @@ -from funlib.geometry import Roi # noqa \ No newline at end of file +from funlib.geometry import Roi # noqa diff --git a/daisy/scheduler.py b/daisy/scheduler.py index b2cf866a..d904514b 100644 --- a/daisy/scheduler.py +++ b/daisy/scheduler.py @@ -82,7 +82,6 @@ def acquire_block(self, task_id): while True: block = self.task_queues[task_id].get_next() if block is not None: - # update states self.task_states[task_id].ready_count -= 1 self.task_states[task_id].processing_count += 1 @@ -90,19 +89,17 @@ def acquire_block(self, task_id): pre_check_ret = self.__precheck(block) if pre_check_ret: logger.debug( - "Skipping block (%s); already processed.", - block.block_id) + "Skipping block (%s); already processed.", block.block_id + ) block.status = BlockStatus.SUCCESS self.task_states[task_id].skipped_count += 1 # adding block so release_block() can remove it - self.task_queues[task_id].processing_blocks.add( - block.block_id) + self.task_queues[task_id].processing_blocks.add(block.block_id) self.release_block(block) continue else: self.task_states[task_id].started = True - self.task_queues[task_id].processing_blocks.add( - block.block_id) + self.task_queues[task_id].processing_blocks.add(block.block_id) return block else: @@ -158,15 +155,13 @@ def release_block(self, block): else: raise RuntimeError( f"Unexpected status for released block: {block.status} {block}" - ) + ) def __init_task(self, task): if task.task_id not in self.task_map: self.task_map[task.task_id] = task num_blocks = self.dependency_graph.num_blocks(task.task_id) - self.task_states[ - task.task_id - ].total_block_count = num_blocks + self.task_states[task.task_id].total_block_count = num_blocks for upstream_task in task.requires(): self.__init_task(upstream_task) @@ -179,8 +174,7 @@ def __queue_ready_block(self, block, index=None): self.task_states[block.task_id].ready_count += 1 def __remove_from_processing_blocks(self, block): - self.task_queues[block.task_id].processing_blocks.remove( - block.block_id) + self.task_queues[block.task_id].processing_blocks.remove(block.block_id) self.task_states[block.task_id].processing_count -= 1 def __update_ready_queue(self, ready_blocks): @@ -200,6 +194,5 @@ def __precheck(self, block): else: return False except Exception: - logger.exception( - f"pre_check() exception for block {block.block_id}") + logger.exception(f"pre_check() exception for block {block.block_id}") return False diff --git a/daisy/server_observer.py b/daisy/server_observer.py index 48edf6ca..d530c53e 100644 --- a/daisy/server_observer.py +++ b/daisy/server_observer.py @@ -1,7 +1,5 @@ class ServerObserver: - def __init__(self, server): - self.server = server server.register_observer(self) @@ -25,7 +23,6 @@ def on_server_exit(self): class 
ServerObservee: - def __init__(self): self.observers = [] diff --git a/daisy/task_worker_pools.py b/daisy/task_worker_pools.py index cf5596cc..49e41fe8 100644 --- a/daisy/task_worker_pools.py +++ b/daisy/task_worker_pools.py @@ -7,9 +7,7 @@ class TaskWorkerPools(ServerObserver): - def __init__(self, tasks, server, max_block_failures=3): - super().__init__(server) logger.debug("Creating worker pools") @@ -17,39 +15,36 @@ def __init__(self, tasks, server, max_block_failures=3): task.task_id: WorkerPool( task.spawn_worker_function, Context( - hostname=server.hostname, - port=server.port, - task_id=task.task_id)) + hostname=server.hostname, port=server.port, task_id=task.task_id + ), + ) for task in tasks } self.max_block_failures = max_block_failures self.failure_counts = {} def recruit_workers(self, tasks): - for task_id, worker_pool in self.worker_pools.items(): if task_id in tasks: logger.debug( "Setting number of workers for task %s to %d", task_id, - tasks[task_id].num_workers) + tasks[task_id].num_workers, + ) worker_pool.set_num_workers(tasks[task_id].num_workers) def stop(self): - logger.debug("Stopping all workers") for worker_pool in self.worker_pools.values(): worker_pool.stop() def check_worker_health(self): - for worker_pool in self.worker_pools.values(): worker_pool.check_for_errors() def on_block_failure(self, block, exception, context): - - task_id = context['task_id'] - worker_id = int(context['worker_id']) + task_id = context["task_id"] + worker_id = int(context["worker_id"]) if task_id not in self.failure_counts: self.failure_counts[task_id] = {} @@ -60,10 +55,9 @@ def on_block_failure(self, block, exception, context): self.failure_counts[task_id][worker_id] += 1 if self.failure_counts[task_id][worker_id] > self.max_block_failures: - logger.error( - "Worker %s failed too many times, restarting this worker...", - context) + "Worker %s failed too many times, restarting this worker...", context + ) self.failure_counts[task_id][worker_id] = 0 worker_pool = self.worker_pools[task_id] diff --git a/daisy/tcp/internal_messages.py b/daisy/tcp/internal_messages.py index f067ff35..01dbddf7 100644 --- a/daisy/tcp/internal_messages.py +++ b/daisy/tcp/internal_messages.py @@ -2,22 +2,25 @@ class InternalTCPMessage(TCPMessage): - '''TCP messages used only between :class:`TCPServer` and :class:`TCPClient` + """TCP messages used only between :class:`TCPServer` and :class:`TCPClient` for internal communication. - ''' + """ + pass class NotifyClientDisconnect(InternalTCPMessage): - '''Message to be sent from a :class:`TCPClient` to :class:`TCPServer` to + """Message to be sent from a :class:`TCPClient` to :class:`TCPServer` to initiate a disconnect. - ''' + """ + pass class AckClientDisconnect(InternalTCPMessage): - '''Message to be sent from a :class:`TCPServer` to :class:`TCPClient` to + """Message to be sent from a :class:`TCPServer` to :class:`TCPClient` to confirm a disconnect, i.e., the server will no longer listen to messages received from this client and the associated stream can be closed. - ''' + """ + pass diff --git a/daisy/tcp/io_looper.py b/daisy/tcp/io_looper.py index fe303579..c7b81a92 100644 --- a/daisy/tcp/io_looper.py +++ b/daisy/tcp/io_looper.py @@ -7,7 +7,7 @@ class IOLooper: - '''Base class for every class that needs access to tornado's IOLoop in a + """Base class for every class that needs access to tornado's IOLoop in a separate thread. Attributes: @@ -16,7 +16,7 @@ class IOLooper: The IO loop to be used in subclasses. Will run in a singleton thread per process. 
- ''' + """ threads = {} ioloops = {} @@ -27,22 +27,19 @@ def clear(): IOLooper.ioloops = {} def __init__(self): - pid = os.getpid() if pid not in IOLooper.threads: - logger.debug("Creating new IOLoop for process %d...", pid) self.ioloop = tornado.ioloop.IOLoop() self.ioloops[pid] = self.ioloop logger.debug("Starting io loop for process %d...", pid) IOLooper.threads[pid] = threading.Thread( - target=self.ioloop.start, - daemon=True) + target=self.ioloop.start, daemon=True + ) IOLooper.threads[pid].start() else: - logger.debug("Reusing IOLoop for process %d...", pid) self.ioloop = self.ioloops[pid] diff --git a/daisy/tcp/tcp_client.py b/daisy/tcp/tcp_client.py index 77ee3fcf..926d77e9 100644 --- a/daisy/tcp/tcp_client.py +++ b/daisy/tcp/tcp_client.py @@ -1,9 +1,7 @@ from .exceptions import NotConnected from .io_looper import IOLooper from .tcp_stream import TCPStream -from .internal_messages import ( - AckClientDisconnect, - NotifyClientDisconnect) +from .internal_messages import AckClientDisconnect, NotifyClientDisconnect import logging import queue import time @@ -13,7 +11,7 @@ class TCPClient(IOLooper): - '''A TCP client to handle client-server communication through + """A TCP client to handle client-server communication through :class:`TCPMessage` objects. Args: @@ -22,10 +20,9 @@ class TCPClient(IOLooper): port (int): The hostname and port of the :class:`TCPServer` to connect to. - ''' + """ def __init__(self, host, port): - super().__init__() logger.debug("Creating new TCP client...") @@ -40,26 +37,25 @@ def __init__(self, host, port): self.connect() def __del__(self): - if self.connected(): self.disconnect() def connect(self): - '''Connect to the server and start the message receive event loop.''' + """Connect to the server and start the message receive event loop.""" logger.debug("Connecting to server at %s:%d...", self.host, self.port) self.ioloop.add_callback(self._connect) while not self.connected(): self._check_for_errors() - time.sleep(.1) + time.sleep(0.1) logger.debug("...connected") self.ioloop.add_callback(self._receive) def disconnect(self): - '''Gracefully close the connection to the server.''' + """Gracefully close the connection to the server.""" if not self.connected(): logger.warn("Called disconnect() on disconnected client") @@ -69,23 +65,23 @@ def disconnect(self): self.stream.send_message(NotifyClientDisconnect()) while self.connected(): - time.sleep(.1) + time.sleep(0.1) logger.debug("Disconnected") def connected(self): - '''Check whether this client has a connection to the server.''' + """Check whether this client has a connection to the server.""" return self.stream is not None def send_message(self, message): - '''Send a message to the server. + """Send a message to the server. Args: message (:class:`TCPMessage`): Message to send over to the server. - ''' + """ self._check_for_errors() @@ -95,7 +91,7 @@ def send_message(self, message): self.stream.send_message(message) def get_message(self, timeout=None): - '''Get a message that was sent to this client. + """Get a message that was sent to this client. Args: @@ -104,7 +100,7 @@ def get_message(self, timeout=None): If set, wait up to `timeout` seconds for a message to arrive. If no message is available after the timeout, returns ``None``. If not set, wait until a message arrived. 
- ''' + """ self._check_for_errors() @@ -112,15 +108,12 @@ def get_message(self, timeout=None): raise NotConnected() try: - return self.message_queue.get(block=True, timeout=timeout) except queue.Empty: - return None def _check_for_errors(self): - try: exception = self.exception_queue.get(block=False) raise exception @@ -128,7 +121,7 @@ def _check_for_errors(self): return async def _connect(self): - '''Async method to connect to the TCPServer.''' + """Async method to connect to the TCPServer.""" try: stream = await self.client.connect(self.host, self.port) @@ -139,19 +132,16 @@ async def _connect(self): self.stream = TCPStream(stream, (self.host, self.port)) async def _receive(self): - '''Loop that receives messages from the server.''' + """Loop that receives messages from the server.""" logger.debug("Entering receive loop") while self.connected(): - try: - # raises StreamClosedError message = await self.stream._get_message() if isinstance(message, AckClientDisconnect): - # server acknowledged disconnect, close connection on # our side and break out of event loop try: @@ -161,11 +151,9 @@ async def _receive(self): return else: - self.message_queue.put(message) except Exception as e: - try: self.exception_queue.put(e) self.stream.close() diff --git a/daisy/tcp/tcp_message.py b/daisy/tcp/tcp_message.py index 818fc1c5..ab6829fd 100644 --- a/daisy/tcp/tcp_message.py +++ b/daisy/tcp/tcp_message.py @@ -1,5 +1,5 @@ class TCPMessage: - '''A message, to be sent between :class:`TCPServer` and :class:`TCPClient`. + """A message, to be sent between :class:`TCPServer` and :class:`TCPClient`. Args: @@ -12,8 +12,8 @@ class TCPMessage: The stream the message was received from. Will be set by :class:`TCPStream` and is ``None`` for messages that have not been sent. - ''' - def __init__(self, payload=None): + """ + def __init__(self, payload=None): self.payload = payload self.stream = None diff --git a/daisy/tcp/tcp_server.py b/daisy/tcp/tcp_server.py index a949cdf8..c06d6e95 100644 --- a/daisy/tcp/tcp_server.py +++ b/daisy/tcp/tcp_server.py @@ -1,7 +1,5 @@ from .exceptions import NoFreePort -from .internal_messages import ( - AckClientDisconnect, - NotifyClientDisconnect) +from .internal_messages import AckClientDisconnect, NotifyClientDisconnect from .io_looper import IOLooper from .tcp_stream import TCPStream import logging @@ -13,17 +11,16 @@ class TCPServer(tornado.tcpserver.TCPServer, IOLooper): - '''A TCP server to handle client-server communication through + """A TCP server to handle client-server communication through :class:`Message` objects. Args: max_port_tries (int, optional): How many times to try to find an empty random port. - ''' + """ def __init__(self, max_port_tries=1000): - tornado.tcpserver.TCPServer.__init__(self) IOLooper.__init__(self) @@ -33,23 +30,21 @@ def __init__(self, max_port_tries=1000): # find empty port, start listening for i in range(max_port_tries): - try: - self.listen(0) # 0 == random port break except OSError: - if i == self.max_port_tries - 1: raise NoFreePort( - "Could not find a free port after %d tries " % - self.max_port_tries) + "Could not find a free port after %d tries " + % self.max_port_tries + ) self.address = self._get_address() def get_message(self, timeout=None): - '''Get a message that was sent to this server. + """Get a message that was sent to this server. If the stream to any of the connected clients is closed, raises a :class:`StreamClosedError` for this client. 
Other TCP related @@ -62,20 +57,18 @@ def get_message(self, timeout=None): If set, wait up to `timeout` seconds for a message to arrive. If no message is available after the timeout, returns ``None``. If not set, wait until a message arrived. - ''' + """ self._check_for_errors() try: - return self.message_queue.get(block=True, timeout=timeout) except queue.Empty: - return None def disconnect(self): - '''Close all open streams to clients.''' + """Close all open streams to clients.""" streams = list(self.client_streams) # avoid set change error for stream in streams: @@ -83,7 +76,7 @@ def disconnect(self): stream.close() async def handle_stream(self, stream, address): - ''' Overrides a function from tornado's TCPServer, and is called + """Overrides a function from tornado's TCPServer, and is called whenever there is a new IOStream from an incoming connection (not whenever there is new data in the IOStream). @@ -96,7 +89,7 @@ async def handle_stream(self, stream, address): address (tuple): host, port that new connection comes from - ''' + """ logger.debug("Received new connection from %s:%d", *address) stream = TCPStream(stream, address) @@ -104,13 +97,10 @@ async def handle_stream(self, stream, address): self.client_streams.add(stream) while True: - try: - message = await stream._get_message() if isinstance(message, NotifyClientDisconnect): - # client notifies that it disconnects, send a response # indicating we are no longer using this stream and break # out of event loop @@ -119,11 +109,9 @@ async def handle_stream(self, stream, address): return else: - self.message_queue.put(message) except Exception as e: - try: self.exception_queue.put(e) finally: @@ -131,7 +119,6 @@ async def handle_stream(self, stream, address): return def _check_for_errors(self): - try: exception = self.exception_queue.get(block=False) raise exception @@ -139,7 +126,7 @@ def _check_for_errors(self): return def _get_address(self): - '''Get the host and port of the tcp server''' + """Get the host and port of the tcp server""" sock = self._sockets[list(self._sockets.keys())[0]] port = sock.getsockname()[1] diff --git a/daisy/tcp/tcp_stream.py b/daisy/tcp/tcp_stream.py index 375f2c48..ebfd12cb 100644 --- a/daisy/tcp/tcp_stream.py +++ b/daisy/tcp/tcp_stream.py @@ -9,7 +9,7 @@ class TCPStream(IOLooper): - '''Wrapper around :class:`tornado.iostream.IOStream` to send + """Wrapper around :class:`tornado.iostream.IOStream` to send :class:`TCPMessage` objects. Args: @@ -19,7 +19,7 @@ class TCPStream(IOLooper): address (tuple): The address the stream originates from. - ''' + """ def __init__(self, stream, address): super().__init__() @@ -27,7 +27,7 @@ def __init__(self, stream, address): self.address = address def send_message(self, message): - '''Send a message through this stream asynchronously. + """Send a message through this stream asynchronously. If the stream is closed, raises a :class:`StreamClosedError`. Successful return of this function does not guarantee that the message @@ -39,7 +39,7 @@ def send_message(self, message): message (:class:`daisy.TCPMessage`): Message to send over the stream. 
- ''' + """ if self.stream is None: raise StreamClosedError(*self.address) @@ -47,7 +47,7 @@ def send_message(self, message): self.ioloop.add_callback(self._send_message, message) def close(self): - '''Close this stream.''' + """Close this stream.""" try: self.stream.close() except Exception: @@ -56,50 +56,43 @@ def close(self): self.stream = None def closed(self): - '''True if this stream was closed.''' + """True if this stream was closed.""" if self.stream is None: return True return self.stream.closed() async def _send_message(self, message): - if self.stream is None: logger.error("No TCPStream available, can't send message.") pickled_data = pickle.dumps(message) - message_size_bytes = struct.pack('I', len(pickled_data)) + message_size_bytes = struct.pack("I", len(pickled_data)) message_bytes = message_size_bytes + pickled_data try: - await self.stream.write(message_bytes) except AttributeError: - # self.stream can be None even though we check earlier, due to race # conditions logger.error("No TCPStream available, can't send message.") pass except tornado.iostream.StreamClosedError: - logger.error("TCPStream lost connection while sending data.") self.stream = None async def _get_message(self): - if self.stream is None: raise StreamClosedError(*self.address) try: - size = await self.stream.read_bytes(4) - size = struct.unpack('I', size)[0] - assert (size < 65535) # TODO: parameterize max message size + size = struct.unpack("I", size)[0] + assert size < 65535 # TODO: parameterize max message size pickled_data = await self.stream.read_bytes(size) except tornado.iostream.StreamClosedError: - self.stream = None raise StreamClosedError(*self.address) @@ -108,5 +101,4 @@ async def _get_message(self): return message def __repr__(self): - return f"{self.address[0]}:{self.address[1]}" diff --git a/daisy/worker.py b/daisy/worker.py index 6c5cfa0f..f16d6355 100644 --- a/daisy/worker.py +++ b/daisy/worker.py @@ -8,8 +8,8 @@ logger = logging.getLogger(__name__) -class Worker(): - '''Create and start a worker, running a user-specified function in its own +class Worker: + """Create and start a worker, running a user-specified function in its own process. Args: @@ -24,9 +24,9 @@ class Worker(): If given, the context will be passed on to the worker via environment variables. - ''' + """ - __next_id = multiprocessing.Value('L') + __next_id = multiprocessing.Value("L") @staticmethod def get_next_id(): @@ -36,32 +36,30 @@ def get_next_id(): return next_id def __init__(self, spawn_function, context=None, error_queue=None): - self.spawn_function = spawn_function self.worker_id = Worker.get_next_id() if context is None: self.context = Context() else: self.context = context.copy() - self.context['worker_id'] = self.worker_id + self.context["worker_id"] = self.worker_id self.error_queue = error_queue self.process = None self.start() def start(self): - '''Start this worker. Note that workers are automatically started when - created. Use this function to re-start a stopped worker.''' + """Start this worker. Note that workers are automatically started when + created. 
Use this function to re-start a stopped worker.""" if self.process is not None: return - self.process = multiprocessing.Process( - target=lambda: self.__spawn_wrapper()) + self.process = multiprocessing.Process(target=lambda: self.__spawn_wrapper()) self.process.start() def stop(self): - '''Stop this worker.''' + """Stop this worker.""" if self.process is None: return @@ -76,35 +74,31 @@ def stop(self): self.process = None def __spawn_wrapper(self): - '''Thin wrapper around the user-specified spawn function to set - environment variables, redirect output, and to capture exceptions.''' + """Thin wrapper around the user-specified spawn function to set + environment variables, redirect output, and to capture exceptions.""" try: - os.environ[self.context.ENV_VARIABLE] = self.context.to_env() log_base = daisy_logging.get_worker_log_basename( - self.worker_id, - self.context.get('task_id', None)) + self.worker_id, self.context.get("task_id", None) + ) daisy_logging.redirect_stdouterr(log_base) self.spawn_function() except Exception as e: - logger.error("%s received exception: %s", self, e) if self.error_queue: try: self.error_queue.put(e, timeout=1) except queue.Full: logger.error( - "%s failed to forward exception, error queue is full", - self) + "%s failed to forward exception, error queue is full", self + ) except KeyboardInterrupt: - logger.debug("%s received ^C", self) def __repr__(self): - return "worker (%s)" % self.context diff --git a/daisy/worker_pool.py b/daisy/worker_pool.py index 5919cf8e..3e5702c2 100644 --- a/daisy/worker_pool.py +++ b/daisy/worker_pool.py @@ -9,7 +9,7 @@ class WorkerPool: - '''Manages a pool of workers in individual processes. All workers are + """Manages a pool of workers in individual processes. All workers are spawned by the same user-specified function. Args: @@ -27,10 +27,9 @@ class WorkerPool: A context to pass on to workers through environment variables. Will be augmented with ``worker_id``, a unique ID for each worker that is spawned by this pool. - ''' + """ def __init__(self, spawn_worker_function, context=None): - if context is None: context = Context() @@ -42,7 +41,7 @@ def __init__(self, spawn_worker_function, context=None): self.error_queue = multiprocessing.Queue(100) def set_num_workers(self, num_workers): - '''Set the number of workers in this pool. + """Set the number of workers in this pool. If higher than the current number of running workers, new workers will be spawned using ``spawn_worker_function``. @@ -56,12 +55,11 @@ def set_num_workers(self, num_workers): num_workers (int): The new number of workers for this pool. - ''' + """ logger.debug("setting number of workers to %d", num_workers) with self.workers_lock: - diff = num_workers - len(self.workers) logger.debug("current number of workers: %d", len(self.workers)) @@ -75,8 +73,8 @@ def inc_num_workers(self, num_workers): self.__start_workers(num_workers) def stop(self, worker_id=None): - '''Stop all current workers in this pool (``worker_id == None``) or a - specific worker.''' + """Stop all current workers in this pool (``worker_id == None``) or a + specific worker.""" if worker_id is None: self.set_num_workers(0) @@ -91,11 +89,11 @@ def stop(self, worker_id=None): del self.workers[worker_id] def check_for_errors(self): - '''If a worker fails with an exception, this exception will be queued + """If a worker fails with an exception, this exception will be queued in this pool to be propagated to the process that created the pool. 
Call this function periodically to check the queue and raise exceptions coming from the workers in the calling process. - ''' + """ try: error = self.error_queue.get(block=False) @@ -105,19 +103,14 @@ def check_for_errors(self): pass def __start_workers(self, n): - logger.debug("starting %d new workers", n) new_workers = [ Worker(self.spawn_function, self.context, self.error_queue) for _ in range(n) ] - self.workers.update({ - worker.worker_id: worker - for worker in new_workers - }) + self.workers.update({worker.worker_id: worker for worker in new_workers}) def __stop_workers(self, n): - logger.debug("stopping %d workers", n) sentenced_worker_ids = list(self.workers.keys())[-n:] diff --git a/docs/conf.py b/docs/conf.py index 59459d62..527a8b60 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,14 +19,14 @@ # -- Project information ----------------------------------------------------- -project = 'daisy' -copyright = '2019, Jan Funke, Tri Nguyen, Carolin Malin-Mayor, Arlo Sheridan, Philipp Hanslovsky, Chris Barnes' -author = 'Jan Funke, Tri Nguyen, Carolin Malin-Mayor, Arlo Sheridan, Philipp Hanslovsky, Chris Barnes' +project = "daisy" +copyright = "2019, Jan Funke, Tri Nguyen, Carolin Malin-Mayor, Arlo Sheridan, Philipp Hanslovsky, Chris Barnes" +author = "Jan Funke, Tri Nguyen, Carolin Malin-Mayor, Arlo Sheridan, Philipp Hanslovsky, Chris Barnes" # The short X.Y version -version = '' +version = "" # The full version, including alpha/beta/rc tags -release = 'v0.2' +release = "v0.2" # -- General configuration --------------------------------------------------- @@ -39,22 +39,22 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.mathjax', - 'sphinx.ext.viewcode', + "sphinx.ext.autodoc", + "sphinx.ext.mathjax", + "sphinx.ext.viewcode", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -66,7 +66,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The name of the Pygments (syntax highlighting) style to use. pygments_style = None @@ -77,7 +77,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -88,7 +88,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # Custom sidebar templates, must be a dictionary that maps document names # to template names. 
@@ -104,7 +104,7 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'daisydoc' +htmlhelp_basename = "daisydoc" # -- Options for LaTeX output ------------------------------------------------ @@ -113,15 +113,12 @@ # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. # # 'preamble': '', - # Latex figure (float) alignment # # 'figure_align': 'htbp', @@ -131,8 +128,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'daisy.tex', 'daisy Documentation', - 'Jan Funke, Tri Nguyen, Carolin Malin-Mayor, Arlo Sheridan, Philipp Hanslovsky, Chris Barnes', 'manual'), + ( + master_doc, + "daisy.tex", + "daisy Documentation", + "Jan Funke, Tri Nguyen, Carolin Malin-Mayor, Arlo Sheridan, Philipp Hanslovsky, Chris Barnes", + "manual", + ), ] @@ -140,10 +142,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'daisy', 'daisy Documentation', - [author], 1) -] +man_pages = [(master_doc, "daisy", "daisy Documentation", [author], 1)] # -- Options for Texinfo output ---------------------------------------------- @@ -152,9 +151,15 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'daisy', 'daisy Documentation', - author, 'daisy', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + "daisy", + "daisy Documentation", + author, + "daisy", + "One line description of project.", + "Miscellaneous", + ), ] @@ -173,7 +178,7 @@ # epub_uid = '' # A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] +epub_exclude_files = ["search.html"] # -- Extension configuration ------------------------------------------------- diff --git a/examples/basic_workflow.py b/examples/basic_workflow.py index 0958c31e..e659a2ec 100644 --- a/examples/basic_workflow.py +++ b/examples/basic_workflow.py @@ -11,29 +11,45 @@ def process_function(b): return 0 -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--total_roi_size', '-t', nargs='+', - help="Size of total region to process", - default=[100, 100]) - parser.add_argument('--block_read_size', '-r', nargs='+', - help="Size of block read region", - default=[20, 20]) - parser.add_argument('--block_write_size', '-w', nargs='+', - help="Size of block write region", - default=[16, 16]) - parser.add_argument('--num_workers', '-nw', type=int, - help="Number of processes to spawn", - default=1) - parser.add_argument('--read_write_conflict', '-rwc', action='store_true', - help="Flag to not schedule overlapping blocks" - " at the same time. 
Default is false") + parser.add_argument( + "--total_roi_size", + "-t", + nargs="+", + help="Size of total region to process", + default=[100, 100], + ) + parser.add_argument( + "--block_read_size", + "-r", + nargs="+", + help="Size of block read region", + default=[20, 20], + ) + parser.add_argument( + "--block_write_size", + "-w", + nargs="+", + help="Size of block write region", + default=[16, 16], + ) + parser.add_argument( + "--num_workers", "-nw", type=int, help="Number of processes to spawn", default=1 + ) + parser.add_argument( + "--read_write_conflict", + "-rwc", + action="store_true", + help="Flag to not schedule overlapping blocks" + " at the same time. Default is false", + ) args = parser.parse_args() ndims = len(args.total_roi_size) # define total region of interest (roi) - total_roi_start = daisy.Coordinate((0,)*ndims) + total_roi_start = daisy.Coordinate((0,) * ndims) total_roi_size = daisy.Coordinate(args.total_roi_size) total_roi = daisy.Roi(total_roi_start, total_roi_size) @@ -46,9 +62,10 @@ def process_function(b): # call run_blockwise daisy.run_blockwise( - total_roi, - block_read_roi, - block_write_roi, - process_function=process_function, - read_write_conflict=args.read_write_conflict, - num_workers=args.num_workers) + total_roi, + block_read_roi, + block_write_roi, + process_function=process_function, + read_write_conflict=args.read_write_conflict, + num_workers=args.num_workers, + ) diff --git a/examples/batch_task.py b/examples/batch_task.py index 9c80d14d..db424fb5 100644 --- a/examples/batch_task.py +++ b/examples/batch_task.py @@ -11,16 +11,16 @@ import daisy daisy_version = float(pkg_resources.get_distribution("daisy").version) -assert daisy_version >= 1, \ - f"This script was written for daisy v1.0 but current installed version " \ +assert daisy_version >= 1, ( + f"This script was written for daisy v1.0 but current installed version " f"is {daisy_version}" +) logging.basicConfig(level=logging.INFO) logger = logging.getLogger("BatchTask") -class Database(): - +class Database: def __init__(self, db_host, db_id, overwrite=False): self.table_name = "completion_db_col" @@ -28,7 +28,7 @@ def __init__(self, db_host, db_id, overwrite=False): # Use SQLite self.use_sql = True os.makedirs("daisy_db", exist_ok=True) - self.con = sqlite3.connect(f'daisy_db/{db_id}.db', check_same_thread=False) + self.con = sqlite3.connect(f"daisy_db/{db_id}.db", check_same_thread=False) self.cur = self.con.cursor() if overwrite: @@ -54,75 +54,88 @@ def __init__(self, db_host, db_id, overwrite=False): if self.table_name not in db.list_collection_names(): self.completion_db = db[self.table_name] self.completion_db.create_index( - [('block_id', pymongo.ASCENDING)], - name='block_id') + [("block_id", pymongo.ASCENDING)], name="block_id" + ) else: self.completion_db = db[self.table_name] def check(self, block_id): - if self.use_sql: - block_id = '_'.join([str(s) for s in block_id]) + block_id = "_".join([str(s) for s in block_id]) res = self.cur.execute( - f"SELECT * FROM {self.table_name} where block_id = '{block_id}'").fetchall() + f"SELECT * FROM {self.table_name} where block_id = '{block_id}'" + ).fetchall() if len(res): return True else: - if self.completion_db.count_documents({'block_id': block_id}) >= 1: + if self.completion_db.count_documents({"block_id": block_id}) >= 1: return True return False def add_finished(self, block_id): - if self.use_sql: - block_id = '_'.join([str(s) for s in block_id]) + block_id = "_".join([str(s) for s in block_id]) self.cur.execute(f"INSERT INTO 
{self.table_name} VALUES ('{block_id}')") self.con.commit() else: - document = { - 'block_id': block_id - } + document = {"block_id": block_id} self.completion_db.insert_one(document) -class BatchTask(): - '''Example base class for a batchable Daisy task. +class BatchTask: + """Example base class for a batchable Daisy task. This class takes care of some plumbing work such as creating a database to keep track of finished block and writing a config and batch file that users can submit jobs to a job system like SLURM. Derived tasks will only need to implement the code for computing the task. - ''' + """ @staticmethod def parse_args(ap): - try: ap.add_argument( - "--db_host", type=str, - help='MongoDB database host name. If `None` (default), use SQLite', - default=None) - # default='10.117.28.139') + "--db_host", + type=str, + help="MongoDB database host name. If `None` (default), use SQLite", + default=None, + ) + # default='10.117.28.139') ap.add_argument( - "--db_name", type=str, help='MongoDB database project name', - default=None) + "--db_name", + type=str, + help="MongoDB database project name", + default=None, + ) ap.add_argument( - "--overwrite", type=int, help='Whether to overwrite completed blocks', - default=0) + "--overwrite", + type=int, + help="Whether to overwrite completed blocks", + default=0, + ) ap.add_argument( - "--num_workers", type=int, help='Number of workers to run', - default=4) + "--num_workers", type=int, help="Number of workers to run", default=4 + ) ap.add_argument( - "--no_launch_workers", type=int, help='Whether to run workers automatically', - default=0) + "--no_launch_workers", + type=int, + help="Whether to run workers automatically", + default=0, + ) ap.add_argument( - "--config_hash", type=str, help='config string, used to keep track of progress', - default=None) + "--config_hash", + type=str, + help="config string, used to keep track of progress", + default=None, + ) ap.add_argument( - "--task_name", type=str, help='Name of task, default to class name', - default=None) + "--task_name", + type=str, + help="Name of task, default to class name", + default=None, + ) except argparse.ArgumentError as e: print("Current task has conflicting arguments with BatchTask!") raise e @@ -130,16 +143,15 @@ def parse_args(ap): return vars(ap.parse_args()) def __init__(self, config=None, config_file=None, task_id=None): - if config_file: print(f"Loading from config_file: {config_file}") - with open(config_file, 'r') as f: + with open(config_file, "r") as f: config = json.load(f) assert config is not None for key in config: - setattr(self, '%s' % key, config[key]) + setattr(self, "%s" % key, config[key]) if self.task_name is None: self.task_name = str(self.__class__.__name__) @@ -147,13 +159,18 @@ def __init__(self, config=None, config_file=None, task_id=None): self.__init_config = copy.deepcopy(config) if self.config_hash is None: - config_str = ''.join(['%s' % (v,) for k,v in config.items() - if k not in ['overwrite', 'num_workers', 'no_launch_workers']]) + config_str = "".join( + [ + "%s" % (v,) + for k, v in config.items() + if k not in ["overwrite", "num_workers", "no_launch_workers"] + ] + ) self.config_hash = str(hashlib.md5(config_str.encode()).hexdigest()) config_hash_short = self.config_hash[0:8] - self.db_id = '%s_%s' % (self.task_name, self.config_hash) - db_id_short = '%s_%s' % (self.task_name, config_hash_short) + self.db_id = "%s_%s" % (self.task_name, self.config_hash) + db_id_short = "%s_%s" % (self.task_name, config_hash_short) # if not given, we need to 
give the task a unique id for chaining if task_id is None: @@ -168,67 +185,69 @@ def _task_init(self, config): assert False, "Function needs to be implemented by subclass" def prepare_task(self): - '''Called by user to get a `daisy.Task`. It should call + """Called by user to get a `daisy.Task`. It should call `_write_config()` and return with a call to `_prepare_task()` Returns: `daisy.Task` object - ''' + """ assert False, "Function needs to be implemented by subclass" def _write_config(self, worker_filename, extra_config=None): - '''Make a config file for workers. Workers can then be run on the + """Make a config file for workers. Workers can then be run on the command line on potentially a different machine and use this file to initialize its variables. Args: extra_config (``dict``, optional): Any extra configs that should be written for workers - ''' + """ config = self.__init_config if extra_config: for k in extra_config: config[k] = extra_config[k] - self.config_file = os.path.join( - '.run_configs', '%s.config' % self.db_id) + self.config_file = os.path.join(".run_configs", "%s.config" % self.db_id) - self.new_actor_cmd = 'python %s run_worker %s' % ( - worker_filename, self.config_file) + self.new_actor_cmd = "python %s run_worker %s" % ( + worker_filename, + self.config_file, + ) if self.db_name is None: - self.db_name = '%s' % self.db_id + self.db_name = "%s" % self.db_id - config['db_id'] = self.db_id + config["db_id"] = self.db_id - os.makedirs('.run_configs', exist_ok=True) - with open(self.config_file, 'w') as f: + os.makedirs(".run_configs", exist_ok=True) + with open(self.config_file, "w") as f: json.dump(config, f) # write batch script - self.sbatch_file = os.path.join('.run_configs', '%s.sh' % self.db_id) + self.sbatch_file = os.path.join(".run_configs", "%s.sh" % self.db_id) self.generate_batch_script( self.sbatch_file, self.new_actor_cmd, - log_dir='.logs', + log_dir=".logs", logname=self.db_id, - ) + ) self.write_config_called = True def generate_batch_script( - self, - output_script, - run_cmd, - log_dir, - logname, - cpu_time=11, - queue='short', - cpu_cores=1, - cpu_mem=2, - gpu=None): - '''Example SLURM script.''' + self, + output_script, + run_cmd, + log_dir, + logname, + cpu_time=11, + queue="short", + cpu_cores=1, + cpu_mem=2, + gpu=None, + ): + """Example SLURM script.""" text = [] text.append("#!/bin/bash") @@ -236,7 +255,7 @@ def generate_batch_script( if gpu is not None: text.append("#SBATCH -p gpu") - if gpu == '' or gpu == 'any': + if gpu == "" or gpu == "any": text.append("#SBATCH --gres=gpu:1") else: text.append("#SBATCH --gres=gpu:{}:1".format(gpu)) @@ -250,25 +269,25 @@ def generate_batch_script( text.append("") text.append(run_cmd) - with open(output_script, 'w') as f: - f.write('\n'.join(text)) + with open(output_script, "w") as f: + f.write("\n".join(text)) def _prepare_task( - self, - total_roi, - read_roi, - write_roi, - check_fn=None, - fit='shrink', - read_write_conflict=False, - upstream_tasks=None, - ): - - assert self.write_config_called, ( - "`BatchTask._write_config()` was not called") - - print("Processing total_roi %s with read_roi %s and write_roi %s" % ( - total_roi, read_roi, write_roi)) + self, + total_roi, + read_roi, + write_roi, + check_fn=None, + fit="shrink", + read_write_conflict=False, + upstream_tasks=None, + ): + assert self.write_config_called, "`BatchTask._write_config()` was not called" + + print( + "Processing total_roi %s with read_roi %s and write_roi %s" + % (total_roi, read_roi, write_roi) + ) if check_fn is None: 
check_fn = lambda b: self._default_check_fn(b) @@ -305,7 +324,7 @@ def _prepare_task( ) def _default_check_fn(self, block): - '''The default check function uses database for checking completion''' + """The default check function uses database for checking completion""" if self.overwrite: return False @@ -313,20 +332,20 @@ def _default_check_fn(self, block): return self.database.check(block.block_id) def _worker_impl(self, args): - '''Worker function implementation''' + """Worker function implementation""" assert False, "Function needs to be implemented by subclass" def _new_worker(self): - if not self.no_launch_workers: self.run_worker() def run_worker(self): - '''Wrapper for `_worker_impl()`''' + """Wrapper for `_worker_impl()`""" - assert 'DAISY_CONTEXT' in os.environ, ( - "DAISY_CONTEXT must be defined as an environment variable") - logger.info("WORKER: Running with context %s" % os.environ['DAISY_CONTEXT']) + assert ( + "DAISY_CONTEXT" in os.environ + ), "DAISY_CONTEXT must be defined as an environment variable" + logger.info("WORKER: Running with context %s" % os.environ["DAISY_CONTEXT"]) database = Database(self.db_host, self.db_id) client_scheduler = daisy.Client() @@ -335,14 +354,19 @@ def run_worker(self): with client_scheduler.acquire_block() as block: if block is None: break - logger.info(f'Received block {block}') + logger.info(f"Received block {block}") self._worker_impl(block) database.add_finished(block.block_id) def init_callback_fn(self, context): + print( + "sbatch command: DAISY_CONTEXT={} sbatch --parsable {}\n".format( + context.to_env(), self.sbatch_file + ) + ) - print("sbatch command: DAISY_CONTEXT={} sbatch --parsable {}\n".format( - context.to_env(), self.sbatch_file)) - - print("Terminal command: DAISY_CONTEXT={} {}\n".format( - context.to_env(), self.new_actor_cmd)) + print( + "Terminal command: DAISY_CONTEXT={} {}\n".format( + context.to_env(), self.new_actor_cmd + ) + ) diff --git a/examples/chaining_example.py b/examples/chaining_example.py index 1c8abaa6..87396b5a 100644 --- a/examples/chaining_example.py +++ b/examples/chaining_example.py @@ -5,37 +5,48 @@ import daisy from gaussian_smoothing2 import GaussianSmoothingTask -if __name__ == '__main__': - +if __name__ == "__main__": ap = argparse.ArgumentParser() - ap.add_argument("in_file", type=str, help='The input container') - ap.add_argument("in_ds_name", type=str, help='The name of the dataset') - ap.add_argument("--out_file", type=str, default=None, - help='The output container, defaults to be the same as in_file') - ap.add_argument('--sigma', '-s', type=float, - help="Sigma to use for gaussian filter", - default=2) - ap.add_argument('--block_read_size', '-r', nargs='+', - help="Size of block read region", - default=[20, 200, 200]) - ap.add_argument('--block_write_size', '-w', nargs='+', - help="Size of block write region", - default=[18, 180, 180]) + ap.add_argument("in_file", type=str, help="The input container") + ap.add_argument("in_ds_name", type=str, help="The name of the dataset") + ap.add_argument( + "--out_file", + type=str, + default=None, + help="The output container, defaults to be the same as in_file", + ) + ap.add_argument( + "--sigma", "-s", type=float, help="Sigma to use for gaussian filter", default=2 + ) + ap.add_argument( + "--block_read_size", + "-r", + nargs="+", + help="Size of block read region", + default=[20, 200, 200], + ) + ap.add_argument( + "--block_write_size", + "-w", + nargs="+", + help="Size of block write region", + default=[18, 180, 180], + ) config = 
GaussianSmoothingTask.parse_args(ap) config1 = copy.deepcopy(config) - config1['out_ds_name'] = 'volumes/raw_smoothed' - daisy_task1 = GaussianSmoothingTask( - config1, task_id='Gaussian1').prepare_task() + config1["out_ds_name"] = "volumes/raw_smoothed" + daisy_task1 = GaussianSmoothingTask(config1, task_id="Gaussian1").prepare_task() # here we reuse parameters but set the output dataset of the previous # task as input config2 = copy.deepcopy(config) - config2['in_ds_name'] = 'volumes/raw_smoothed' - config2['out_ds_name'] = 'volumes/raw_smoothed_smoothed' - daisy_task2 = GaussianSmoothingTask( - config2, task_id='Gaussian2').prepare_task(upstream_tasks=[daisy_task1]) + config2["in_ds_name"] = "volumes/raw_smoothed" + config2["out_ds_name"] = "volumes/raw_smoothed_smoothed" + daisy_task2 = GaussianSmoothingTask(config2, task_id="Gaussian2").prepare_task( + upstream_tasks=[daisy_task1] + ) done = daisy.run_blockwise([daisy_task1, daisy_task2]) diff --git a/examples/gaussian_smoothing1.py b/examples/gaussian_smoothing1.py index 21dea4f1..005bf38c 100644 --- a/examples/gaussian_smoothing1.py +++ b/examples/gaussian_smoothing1.py @@ -9,8 +9,9 @@ from batch_task import BatchTask import logging + logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('GaussianSmoothingTask') +logger = logging.getLogger("GaussianSmoothingTask") def smooth(block, dataset, output, sigma=5): @@ -22,39 +23,51 @@ def smooth(block, dataset, output, sigma=5): logger.debug("Got data of shape %s" % str(data.shape)) # apply gaussian filter - r = scipy.ndimage.gaussian_filter( - data, sigma=sigma, mode='constant') + r = scipy.ndimage.gaussian_filter(data, sigma=sigma, mode="constant") # write result to output dataset in block.write_roi - to_write = daisy.Array( - data=r, - roi=block.read_roi, - voxel_size=dataset.voxel_size) + to_write = daisy.Array(data=r, roi=block.read_roi, voxel_size=dataset.voxel_size) output[block.write_roi] = to_write[block.write_roi] logger.debug("Done") return 0 -if __name__ == '__main__': - +if __name__ == "__main__": ap = argparse.ArgumentParser() - ap.add_argument("in_file", type=str, help='The input container') - ap.add_argument("in_ds_name", type=str, help='The name of the dataset') - ap.add_argument("--out_file", type=str, default=None, - help='The output container, defaults to be the same as in_file') - ap.add_argument("--out_ds_name", type=str, default=None, - help='The name of the dataset, defaults to be in_ds_name + smoothed') - ap.add_argument('--sigma', '-s', type=float, - help="Sigma to use for gaussian filter", - default=2) - ap.add_argument('--block_read_size', '-r', nargs='+', - help="Size of block read region", - default=[20, 200, 200]) - ap.add_argument('--block_write_size', '-w', nargs='+', - help="Size of block write region", - default=[18, 180, 180]) - ap.add_argument("--num_workers", type=int, help='Number of workers to run', - default=4) + ap.add_argument("in_file", type=str, help="The input container") + ap.add_argument("in_ds_name", type=str, help="The name of the dataset") + ap.add_argument( + "--out_file", + type=str, + default=None, + help="The output container, defaults to be the same as in_file", + ) + ap.add_argument( + "--out_ds_name", + type=str, + default=None, + help="The name of the dataset, defaults to be in_ds_name + smoothed", + ) + ap.add_argument( + "--sigma", "-s", type=float, help="Sigma to use for gaussian filter", default=2 + ) + ap.add_argument( + "--block_read_size", + "-r", + nargs="+", + help="Size of block read region", + default=[20, 
200, 200], + ) + ap.add_argument( + "--block_write_size", + "-w", + nargs="+", + help="Size of block write region", + default=[18, 180, 180], + ) + ap.add_argument( + "--num_workers", type=int, help="Number of workers to run", default=4 + ) config = ap.parse_args() @@ -66,16 +79,18 @@ def smooth(block, dataset, output, sigma=5): ndims = len(total_roi.get_offset()) # define block read and write rois - assert len(config.block_read_size) == ndims,\ - "Read size must have same dimensions as in_file" - assert len(config.block_write_size) == ndims,\ - "Write size must have same dimensions as in_file" + assert ( + len(config.block_read_size) == ndims + ), "Read size must have same dimensions as in_file" + assert ( + len(config.block_write_size) == ndims + ), "Write size must have same dimensions as in_file" block_read_size = daisy.Coordinate(config.block_read_size) block_write_size = daisy.Coordinate(config.block_write_size) block_read_size *= dataset.voxel_size block_write_size *= dataset.voxel_size context = (block_read_size - block_write_size) / 2 - block_read_roi = daisy.Roi((0,)*ndims, block_read_size) + block_read_roi = daisy.Roi((0,) * ndims, block_read_size) block_write_roi = daisy.Roi(context, block_write_size) # prepare output dataset @@ -83,30 +98,32 @@ def smooth(block, dataset, output, sigma=5): if config.out_file is None: config.out_file = config.in_file if config.out_ds_name is None: - config.out_ds_name = config.in_ds_name + '_smoothed' + config.out_ds_name = config.in_ds_name + "_smoothed" - logger.info(f'Processing data to {config.out_file}/{config.out_ds_name}') + logger.info(f"Processing data to {config.out_file}/{config.out_ds_name}") output_dataset = daisy.prepare_ds( - config.out_file, - config.out_ds_name, - total_roi=output_roi, - voxel_size=dataset.voxel_size, - dtype=dataset.dtype, - write_size=block_write_roi.get_shape()) + config.out_file, + config.out_ds_name, + total_roi=output_roi, + voxel_size=dataset.voxel_size, + dtype=dataset.dtype, + write_size=block_write_roi.get_shape(), + ) # make task task = daisy.Task( - 'GaussianSmoothingTask', - total_roi, - block_read_roi, - block_write_roi, - process_function=lambda b: smooth( - b, dataset, output_dataset, sigma=config.sigma), - read_write_conflict=False, - num_workers=config.num_workers, - fit='shrink' - ) + "GaussianSmoothingTask", + total_roi, + block_read_roi, + block_write_roi, + process_function=lambda b: smooth( + b, dataset, output_dataset, sigma=config.sigma + ), + read_write_conflict=False, + num_workers=config.num_workers, + fit="shrink", + ) # run task ret = daisy.run_blockwise([task]) @@ -114,4 +131,4 @@ def smooth(block, dataset, output, sigma=5): if ret: logger.info("Ran all blocks successfully!") else: - logger.info("Did not run all blocks successfully...") \ No newline at end of file + logger.info("Did not run all blocks successfully...") diff --git a/examples/gaussian_smoothing2.py b/examples/gaussian_smoothing2.py index 5f42583c..fe257983 100644 --- a/examples/gaussian_smoothing2.py +++ b/examples/gaussian_smoothing2.py @@ -9,8 +9,9 @@ from batch_task import BatchTask import logging + logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('GaussianSmoothingTask') +logger = logging.getLogger("GaussianSmoothingTask") def smooth(block, dataset, output, sigma=5): @@ -22,23 +23,17 @@ def smooth(block, dataset, output, sigma=5): logger.debug("Got data of shape %s" % str(data.shape)) # apply gaussian filter - r = scipy.ndimage.gaussian_filter( - data, sigma=sigma, mode='constant') + r = 
scipy.ndimage.gaussian_filter(data, sigma=sigma, mode="constant") # write result to output dataset in block.write_roi - to_write = daisy.Array( - data=r, - roi=block.read_roi, - voxel_size=dataset.voxel_size) + to_write = daisy.Array(data=r, roi=block.read_roi, voxel_size=dataset.voxel_size) output[block.write_roi] = to_write[block.write_roi] logger.debug("Done") return 0 class GaussianSmoothingTask(BatchTask): - def _task_init(self): - # open dataset dataset = daisy.open_ds(self.in_file, self.in_ds_name) @@ -47,16 +42,18 @@ def _task_init(self): ndims = len(total_roi.get_offset()) # define block read and write rois - assert len(self.block_read_size) == ndims,\ - "Read size must have same dimensions as in_file" - assert len(self.block_write_size) == ndims,\ - "Write size must have same dimensions as in_file" + assert ( + len(self.block_read_size) == ndims + ), "Read size must have same dimensions as in_file" + assert ( + len(self.block_write_size) == ndims + ), "Write size must have same dimensions as in_file" block_read_size = daisy.Coordinate(self.block_read_size) block_write_size = daisy.Coordinate(self.block_write_size) block_read_size *= dataset.voxel_size block_write_size *= dataset.voxel_size context = (block_read_size - block_write_size) / 2 - block_read_roi = daisy.Roi((0,)*ndims, block_read_size) + block_read_roi = daisy.Roi((0,) * ndims, block_read_size) block_write_roi = daisy.Roi(context, block_write_size) # prepare output dataset @@ -64,17 +61,18 @@ def _task_init(self): if self.out_file is None: self.out_file = self.in_file if self.out_ds_name is None: - self.out_ds_name = self.in_ds_name + '_smoothed' + self.out_ds_name = self.in_ds_name + "_smoothed" - logger.info(f'Processing data to {self.out_file}/{self.out_ds_name}') + logger.info(f"Processing data to {self.out_file}/{self.out_ds_name}") output_dataset = daisy.prepare_ds( - self.out_file, - self.out_ds_name, - total_roi=output_roi, - voxel_size=dataset.voxel_size, - dtype=dataset.dtype, - write_size=block_write_roi.get_shape()) + self.out_file, + self.out_ds_name, + total_roi=output_roi, + voxel_size=dataset.voxel_size, + dtype=dataset.dtype, + write_size=block_write_roi.get_shape(), + ) # save variables for other functions self.total_roi = total_roi @@ -88,42 +86,60 @@ def prepare_task(self, upstream_tasks=None): worker_filename = os.path.realpath(__file__) self._write_config(worker_filename) return self._prepare_task( - self.total_roi, - self.block_read_roi, - self.block_write_roi, - read_write_conflict=False, - upstream_tasks=upstream_tasks, - ) + self.total_roi, + self.block_read_roi, + self.block_write_roi, + read_write_conflict=False, + upstream_tasks=upstream_tasks, + ) def _worker_impl(self, block): - '''Worker function implementation''' + """Worker function implementation""" smooth(block, self.dataset, self.output_dataset, sigma=self.sigma) -if __name__ == '__main__': - - if len(sys.argv) > 1 and sys.argv[1] == 'run_worker': +if __name__ == "__main__": + if len(sys.argv) > 1 and sys.argv[1] == "run_worker": task = GaussianSmoothingTask(config_file=sys.argv[2]) task.run_worker() else: - ap = argparse.ArgumentParser() - ap.add_argument("in_file", type=str, help='The input container') - ap.add_argument("in_ds_name", type=str, help='The name of the dataset') - ap.add_argument("--out_file", type=str, default=None, - help='The output container, defaults to be the same as in_file') - ap.add_argument("--out_ds_name", type=str, default=None, - help='The name of the dataset, defaults to be in_ds_name + smoothed') - 
ap.add_argument('--sigma', '-s', type=float, - help="Sigma to use for gaussian filter", - default=2) - ap.add_argument('--block_read_size', '-r', nargs='+', - help="Size of block read region", - default=[20, 200, 200]) - ap.add_argument('--block_write_size', '-w', nargs='+', - help="Size of block write region", - default=[18, 180, 180]) + ap.add_argument("in_file", type=str, help="The input container") + ap.add_argument("in_ds_name", type=str, help="The name of the dataset") + ap.add_argument( + "--out_file", + type=str, + default=None, + help="The output container, defaults to be the same as in_file", + ) + ap.add_argument( + "--out_ds_name", + type=str, + default=None, + help="The name of the dataset, defaults to be in_ds_name + smoothed", + ) + ap.add_argument( + "--sigma", + "-s", + type=float, + help="Sigma to use for gaussian filter", + default=2, + ) + ap.add_argument( + "--block_read_size", + "-r", + nargs="+", + help="Size of block read region", + default=[20, 200, 200], + ) + ap.add_argument( + "--block_write_size", + "-w", + nargs="+", + help="Size of block write region", + default=[18, 180, 180], + ) config = GaussianSmoothingTask.parse_args(ap) task = GaussianSmoothingTask(config) @@ -132,4 +148,4 @@ def _worker_impl(self, block): if done: logger.info("Ran all blocks successfully!") else: - logger.info("Did not run all blocks successfully...") \ No newline at end of file + logger.info("Did not run all blocks successfully...") diff --git a/examples/hdf_to_zarr.py b/examples/hdf_to_zarr.py index 36c12e77..b512158a 100644 --- a/examples/hdf_to_zarr.py +++ b/examples/hdf_to_zarr.py @@ -14,7 +14,6 @@ def calculateNearIsotropicDimensions(voxel_size, max_voxel_count): - dims = len(voxel_size) voxel_count = 1 @@ -34,9 +33,7 @@ def calculateNearIsotropicDimensions(voxel_size, max_voxel_count): class HDF2ZarrTask(BatchTask): - def _task_init(self): - logger.info(f"Accessing {self.in_ds_name} in {self.in_file}") try: self.in_ds = daisy.open_ds(self.in_file, self.in_ds_name) @@ -51,35 +48,30 @@ def _task_init(self): elif self.in_ds.n_channel_dims == 1: num_channels = self.in_ds.shape[0] else: - raise RuntimeError( - "more than one channel not yet implemented, sorry...") + raise RuntimeError("more than one channel not yet implemented, sorry...") self.ds_roi = self.in_ds.roi sub_roi = None if self.roi_offset is not None or self.roi_shape is not None: assert self.roi_offset is not None and self.roi_shape is not None - self.schedule_roi = daisy.Roi( - tuple(self.roi_offset), tuple(self.roi_shape)) + self.schedule_roi = daisy.Roi(tuple(self.roi_offset), tuple(self.roi_shape)) sub_roi = self.schedule_roi else: self.schedule_roi = self.in_ds.roi if self.chunk_shape_voxel is None: self.chunk_shape_voxel = calculateNearIsotropicDimensions( - voxel_size, self.max_voxel_count) + voxel_size, self.max_voxel_count + ) logger.info(voxel_size) logger.info(self.chunk_shape_voxel) self.chunk_shape_voxel = Coordinate(self.chunk_shape_voxel) - self.schedule_roi = self.schedule_roi.snap_to_grid( - voxel_size, - mode='grow') - out_ds_roi = self.ds_roi.snap_to_grid( - voxel_size, - mode='grow') + self.schedule_roi = self.schedule_roi.snap_to_grid(voxel_size, mode="grow") + out_ds_roi = self.ds_roi.snap_to_grid(voxel_size, mode="grow") - self.write_size = self.chunk_shape_voxel*voxel_size + self.write_size = self.chunk_shape_voxel * voxel_size scheduling_block_size = self.write_size self.write_roi = daisy.Roi((0, 0, 0), scheduling_block_size) @@ -88,9 +80,9 @@ def _task_init(self): # with sub_roi, the 
coordinates are absolute # so we'd need to align total_roi to the write size too self.schedule_roi = self.schedule_roi.snap_to_grid( - self.write_size, mode='grow') - out_ds_roi = out_ds_roi.snap_to_grid( - self.write_size, mode='grow') + self.write_size, mode="grow" + ) + out_ds_roi = out_ds_roi.snap_to_grid(self.write_size, mode="grow") logger.info(f"out_ds_roi: {out_ds_roi}") logger.info(f"schedule_roi: {self.schedule_roi}") @@ -98,7 +90,7 @@ def _task_init(self): logger.info(f"voxel_size: {voxel_size}") if self.out_file is None: - self.out_file = '.'.join(self.in_file.split('.')[0:-1])+'.zarr' + self.out_file = ".".join(self.in_file.split(".")[0:-1]) + ".zarr" if self.out_ds_name is None: self.out_ds_name = self.in_ds_name @@ -113,17 +105,25 @@ def _task_init(self): dtype=self.in_ds.dtype, num_channels=num_channels, force_exact_write_size=True, - compressor={'id': 'blosc', 'clevel': 3}, + compressor={"id": "blosc", "clevel": 3}, delete=delete, - ) + ) def prepare_task(self): - assert len(self.chunk_shape_voxel) == 3 logger.info( "Rechunking %s/%s to %s/%s with chunk_shape_voxel %s (write_size %s, scheduling %s)" - % (self.in_file, self.in_ds_name, self.out_file, self.out_ds_name, self.chunk_shape_voxel, self.write_size, self.write_roi)) + % ( + self.in_file, + self.in_ds_name, + self.out_file, + self.out_ds_name, + self.chunk_shape_voxel, + self.write_size, + self.write_roi, + ) + ) logger.info("ROI: %s" % self.schedule_roi) worker_filename = os.path.realpath(__file__) @@ -134,14 +134,13 @@ def prepare_task(self): read_roi=self.write_roi, write_roi=self.write_roi, check_fn=lambda b: self.check_fn(b), - ) + ) def _worker_impl(self, block): - '''Worker function implementation''' + """Worker function implementation""" self.out_ds[block.write_roi] = self.in_ds[block.write_roi] def check_fn(self, block): - write_roi = self.out_ds.roi.intersect(block.write_roi) if write_roi.empty: return True @@ -150,39 +149,42 @@ def check_fn(self, block): if __name__ == "__main__": - - if len(sys.argv) > 1 and sys.argv[1] == 'run_worker': + if len(sys.argv) > 1 and sys.argv[1] == "run_worker": task = HDF2ZarrTask(config_file=sys.argv[2]) task.run_worker() else: - ap = argparse.ArgumentParser( - description="Create a zarr/N5 container from hdf.") - ap.add_argument("in_file", type=str, help='The input container') - ap.add_argument("in_ds_name", type=str, help='The name of the dataset') - ap.add_argument( - "--out_file", type=str, default=None, - help='The output container, defaults to be the same as in_file+.zarr' - ) + ap = argparse.ArgumentParser(description="Create a zarr/N5 container from hdf.") + ap.add_argument("in_file", type=str, help="The input container") + ap.add_argument("in_ds_name", type=str, help="The name of the dataset") ap.add_argument( - "--out_ds_name", type=str, default=None, - help='The name of the dataset, defaults to be in_ds_name' - ) - ap.add_argument( - "--chunk_shape_voxel", type=int, help='The size of a chunk in voxels', - nargs='+', default=None - ) + "--out_file", + type=str, + default=None, + help="The output container, defaults to be the same as in_file+.zarr", + ) ap.add_argument( - "--max_voxel_count", type=int, default=256*1024, - help='If chunk_shape_voxel is not given, use this value to calculate' - 'a near isotropic chunk shape', - ) + "--out_ds_name", + type=str, + default=None, + help="The name of the dataset, defaults to be in_ds_name", + ) ap.add_argument( - "--roi_offset", type=int, help='', - nargs='+', default=None) + "--chunk_shape_voxel", + type=int, + 
help="The size of a chunk in voxels", + nargs="+", + default=None, + ) ap.add_argument( - "--roi_shape", type=int, help='', - nargs='+', default=None) + "--max_voxel_count", + type=int, + default=256 * 1024, + help="If chunk_shape_voxel is not given, use this value to calculate" + "a near isotropic chunk shape", + ) + ap.add_argument("--roi_offset", type=int, help="", nargs="+", default=None) + ap.add_argument("--roi_shape", type=int, help="", nargs="+", default=None) config = HDF2ZarrTask.parse_args(ap) task = HDF2ZarrTask(config) diff --git a/examples/minimal_example/server.py b/examples/minimal_example/server.py index 1ca9014c..314c6fd0 100644 --- a/examples/minimal_example/server.py +++ b/examples/minimal_example/server.py @@ -1,7 +1,8 @@ import multiprocessing + # workaround for MacOS: # this needs to be set before importing any library that uses multiprocessing -multiprocessing.set_start_method('fork') +multiprocessing.set_start_method("fork") import logging import daisy import subprocess @@ -10,13 +11,7 @@ logging.basicConfig(level=logging.INFO) -def run_scheduler( - total_roi, - read_roi, - write_roi, - num_workers -): - +def run_scheduler(total_roi, read_roi, write_roi, num_workers): dummy_task = daisy.Task( "dummy_task", total_roi=total_roi, @@ -32,7 +27,6 @@ def run_scheduler( def start_worker(): - worker_id = daisy.Context.from_env()["worker_id"] task_id = daisy.Context.from_env()["task_id"] @@ -46,14 +40,8 @@ def start_worker(): if __name__ == "__main__": - total_roi = daisy.Roi((0, 0, 0), (1024, 1024, 1024)) write_roi = daisy.Roi((0, 0, 0), (256, 256, 256)) read_roi = write_roi.grow((16, 16, 16), (16, 16, 16)) - run_scheduler( - total_roi, - read_roi, - write_roi, - num_workers=10 - ) + run_scheduler(total_roi, read_roi, write_roi, num_workers=10) diff --git a/examples/minimal_example/worker.py b/examples/minimal_example/worker.py index f89280a5..606ec2ca 100644 --- a/examples/minimal_example/worker.py +++ b/examples/minimal_example/worker.py @@ -7,13 +7,11 @@ def test_worker(): - client = daisy.Client() while True: logger.info("getting block") with client.acquire_block() as block: - if block is None: break @@ -26,5 +24,4 @@ def test_worker(): if __name__ == "__main__": - test_worker() diff --git a/examples/visualize.py b/examples/visualize.py index f782ab57..c64eaeef 100644 --- a/examples/visualize.py +++ b/examples/visualize.py @@ -12,95 +12,96 @@ parser = argparse.ArgumentParser() parser.add_argument( - '--file', - '-f', - type=str, - action='append', - help="The path to the container to show") + "--file", "-f", type=str, action="append", help="The path to the container to show" +) parser.add_argument( - '--datasets', - '-d', + "--datasets", + "-d", type=str, - nargs='+', - action='append', - help="The datasets in the container to show") + nargs="+", + action="append", + help="The datasets in the container to show", +) parser.add_argument( - '--graphs', - '-g', + "--graphs", + "-g", type=str, - nargs='+', - action='append', - help="The graphs in the container to show") + nargs="+", + action="append", + help="The graphs in the container to show", +) parser.add_argument( - '--no-browser', - '-n', + "--no-browser", + "-n", type=bool, - nargs='?', + nargs="?", default=False, const=True, - help="If set, do not open a browser, just print a URL") + help="If set, do not open a browser, just print a URL", +) args = parser.parse_args() -neuroglancer.set_server_bind_address('0.0.0.0') +neuroglancer.set_server_bind_address("0.0.0.0") viewer = neuroglancer.Viewer() -def 
to_slice(slice_str): - values = [int(x) for x in slice_str.split(':')] +def to_slice(slice_str): + values = [int(x) for x in slice_str.split(":")] if len(values) == 1: return values[0] return slice(*values) -def parse_ds_name(ds): - tokens = ds.split('[') +def parse_ds_name(ds): + tokens = ds.split("[") if len(tokens) == 1: return ds, None ds, slices = tokens - slices = list(map(to_slice, slices.rstrip(']').split(','))) + slices = list(map(to_slice, slices.rstrip("]").split(","))) return ds, slices -class Project: +class Project: def __init__(self, array, dim, value): self.array = array self.dim = dim self.value = value - self.shape = array.shape[:self.dim] + array.shape[self.dim + 1:] + self.shape = array.shape[: self.dim] + array.shape[self.dim + 1 :] self.dtype = array.dtype def __getitem__(self, key): - slices = key[:self.dim] + (self.value,) + key[self.dim:] + slices = key[: self.dim] + (self.value,) + key[self.dim :] ret = self.array[slices] return ret -def slice_dataset(a, slices): +def slice_dataset(a, slices): dims = a.roi.dims for d, s in list(enumerate(slices))[::-1]: - if isinstance(s, slice): raise NotImplementedError("Slicing not yet implemented!") else: - index = (s - a.roi.get_begin()[d])//a.voxel_size[d] + index = (s - a.roi.get_begin()[d]) // a.voxel_size[d] a.data = Project(a.data, d, index) a.roi = daisy.Roi( - a.roi.get_begin()[:d] + a.roi.get_begin()[d + 1:], - a.roi.get_shape()[:d] + a.roi.get_shape()[d + 1:]) - a.voxel_size = a.voxel_size[:d] + a.voxel_size[d + 1:] + a.roi.get_begin()[:d] + a.roi.get_begin()[d + 1 :], + a.roi.get_shape()[:d] + a.roi.get_shape()[d + 1 :], + ) + a.voxel_size = a.voxel_size[:d] + a.voxel_size[d + 1 :] return a + def open_dataset(f, ds): original_ds = ds ds, slices = parse_ds_name(ds) - slices_str = original_ds[len(ds):] + slices_str = original_ds[len(ds) :] try: dataset_as = [] @@ -129,7 +130,9 @@ def open_dataset(f, ds): if a.roi.dims == 2: print("ROI is 2D, recruiting next channel to z dimension") - a.roi = daisy.Roi((0,) + a.roi.get_begin(), (a.shape[-3],) + a.roi.get_shape()) + a.roi = daisy.Roi( + (0,) + a.roi.get_begin(), (a.shape[-3],) + a.roi.get_shape() + ) a.voxel_size = daisy.Coordinate((1,) + a.voxel_size) if a.roi.dims == 4: @@ -143,34 +146,28 @@ def open_dataset(f, ds): return [(a, ds)] else: - return [([daisy.open_ds(f, f"{ds}/{key}") for key in zarr.open(f)[ds].keys()], ds)] + return [ + ([daisy.open_ds(f, f"{ds}/{key}") for key in zarr.open(f)[ds].keys()], ds) + ] -for f, datasets in zip(args.file, args.datasets): +for f, datasets in zip(args.file, args.datasets): arrays = [] for ds in datasets: try: - print("Adding %s, %s" % (f, ds)) dataset_as = open_dataset(f, ds) except Exception as e: - print(type(e), e) print("Didn't work, checking if this is multi-res...") - scales = glob.glob(os.path.join(f, ds, 's*')) + scales = glob.glob(os.path.join(f, ds, "s*")) if len(scales) == 0: print(f"Couldn't read {ds}, skipping...") raise e - print("Found scales %s" % ([ - os.path.relpath(s, f) - for s in scales - ],)) - a = [ - open_dataset(f, os.path.relpath(scale_ds, f)) - for scale_ds in scales - ] + print("Found scales %s" % ([os.path.relpath(s, f) for s in scales],)) + a = [open_dataset(f, os.path.relpath(scale_ds, f)) for scale_ds in scales] for a in dataset_as: arrays.append(a) @@ -180,13 +177,11 @@ def open_dataset(f, ds): if args.graphs: for f, graphs in zip(args.file, args.graphs): - for graph in graphs: - graph_annotations = [] try: - ids = daisy.open_ds(f, graph + '-ids').data - loc = daisy.open_ds(f, graph + 
'-locations').data + ids = daisy.open_ds(f, graph + "-ids").data + loc = daisy.open_ds(f, graph + "-locations").data except: loc = daisy.open_ds(f, graph).data ids = None @@ -199,15 +194,15 @@ def open_dataset(f, ds): l = np.concatenate([[0], l]) graph_annotations.append( neuroglancer.EllipsoidAnnotation( - center=l[::-1], - radii=(5, 5, 5), - id=i)) + center=l[::-1], radii=(5, 5, 5), id=i + ) + ) graph_layer = neuroglancer.AnnotationLayer( - annotations=graph_annotations, - voxel_size=(1, 1, 1)) + annotations=graph_annotations, voxel_size=(1, 1, 1) + ) with viewer.txn() as s: - s.layers.append(name='graph', layer=graph_layer) + s.layers.append(name="graph", layer=graph_layer) url = str(viewer) print(url) @@ -215,4 +210,4 @@ def open_dataset(f, ds): webbrowser.open_new(url) print("Press ENTER to quit") -input() \ No newline at end of file +input() diff --git a/tests/test_client.py b/tests/test_client.py index d763a4cd..a9dad140 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,14 +1,11 @@ import daisy import unittest import multiprocessing as mp -from daisy.messages import ( - AcquireBlock, ReleaseBlock, SendBlock, - ExceptionMessage) +from daisy.messages import AcquireBlock, ReleaseBlock, SendBlock, ExceptionMessage from daisy.tcp import TCPServer class TestClient(unittest.TestCase): - def run_test_server(self, block, conn): server = TCPServer() conn.send(server.address) @@ -40,18 +37,14 @@ def run_test_server(self, block, conn): def test_basic(self): roi = daisy.Roi((0, 0, 0), (10, 10, 10)) task_id = 1 - block = daisy.Block( - roi, roi, roi, - block_id=1, - task_id=task_id) + block = daisy.Block(roi, roi, roi, block_id=1, task_id=task_id) parent_conn, child_conn = mp.Pipe() - server_process = mp.Process(target=self.run_test_server, - args=(block, child_conn)) + server_process = mp.Process( + target=self.run_test_server, args=(block, child_conn) + ) server_process.start() host, port = parent_conn.recv() - context = daisy.Context( - hostname=host, port=port, - task_id=task_id, worker_id=1) + context = daisy.Context(hostname=host, port=port, task_id=task_id, worker_id=1) client = daisy.Client(context=context) with client.acquire_block() as block: block.status = daisy.BlockStatus.SUCCESS diff --git a/tests/test_dependency_graph.py b/tests/test_dependency_graph.py index 429e01c3..5d447ab7 100644 --- a/tests/test_dependency_graph.py +++ b/tests/test_dependency_graph.py @@ -131,7 +131,6 @@ def test_get_subgraph_blocks(): def test_shrink_downstream_upstream_equivalence(): - total_roi = Roi((0,), (100,)) read_roi = Roi((0,), (7,)) write_roi = Roi((1,), (5,)) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 339f028d..c6b4f555 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -53,6 +53,7 @@ def task_1d(tmpdir): check_function=None, ) + @pytest.fixture def task_no_conflicts(tmpdir): # block ids: @@ -155,6 +156,7 @@ def overlapping_tasks(): ) return task_1, task_2 + def test_simple_no_conflicts(task_no_conflicts): scheduler = Scheduler([task_no_conflicts]) @@ -166,7 +168,7 @@ def test_simple_no_conflicts(task_no_conflicts): assert block.read_roi == expected_block.read_roi assert block.write_roi == expected_block.write_roi assert block.block_id == expected_block.block_id - + block = scheduler.acquire_block(task_no_conflicts.task_id) expected_block = Block( @@ -176,6 +178,7 @@ def test_simple_no_conflicts(task_no_conflicts): assert block.write_roi == expected_block.write_roi assert block.block_id == expected_block.block_id + def 
test_simple_acquire_block(task_1d): scheduler = Scheduler([task_1d]) block = scheduler.acquire_block(task_1d.task_id) diff --git a/tests/test_tcp.py b/tests/test_tcp.py index 9805eb59..5314b33a 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -4,9 +4,7 @@ class TestTCPConnections(unittest.TestCase): - def test_single_connection(self): - # create a server server = TCPServer() host, port = server.address diff --git a/tests/tmpdir_test.py b/tests/tmpdir_test.py index a7cf0bbd..9893689c 100644 --- a/tests/tmpdir_test.py +++ b/tests/tmpdir_test.py @@ -34,7 +34,8 @@ class TmpDirTestCase(unittest.TestCase): tearDownClass should explicitly call the ``super`` method in the method definition. """ - _output_root = '' + + _output_root = "" _cleanup = True def path_to(self, *args): @@ -48,7 +49,8 @@ def path_to_cls(cls, *args): def setUpClass(cls): timestamp = datetime.now().isoformat() cls._output_root = mkdtemp( - prefix='daisy_{}_{}_'.format(cls.__name__, timestamp)) + prefix="daisy_{}_{}_".format(cls.__name__, timestamp) + ) def setUp(self): os.mkdir(self.path_to()) @@ -59,9 +61,9 @@ def tearDown(self): if self._cleanup: shutil.rmtree(path) else: - warn('Directory {} was not deleted'.format(path)) + warn("Directory {} was not deleted".format(path)) except OSError as e: - if '[Errno 2]' in str(e): + if "[Errno 2]" in str(e): pass else: raise @@ -72,13 +74,14 @@ def tearDownClass(cls): if cls._cleanup: os.rmdir(cls.path_to_cls()) else: - warn('Directory {} was not deleted'.format(cls.path_to_cls())) + warn("Directory {} was not deleted".format(cls.path_to_cls())) except OSError as e: - if '[Errno 39]' in str(e): + if "[Errno 39]" in str(e): warn( - 'Directory {} could not be deleted as it still had data ' - 'in it'.format(cls.path_to_cls())) - elif '[Errno 2]' in str(e): + "Directory {} could not be deleted as it still had data " + "in it".format(cls.path_to_cls()) + ) + elif "[Errno 2]" in str(e): pass else: raise
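
For reference, a minimal sketch of the blockwise API that the reformatted examples above exercise. This is not part of the patch: the task name "example_task" and the ROI sizes are illustrative, and the Task/run_blockwise signatures are assumed to match what examples/gaussian_smoothing1.py and examples/minimal_example/server.py use (daisy >= 1.0).

import daisy

def process(block):
    # Called once per block by a daisy worker; block.read_roi and
    # block.write_roi describe the region this call is responsible for.
    print(f"processing {block.block_id} with write ROI {block.write_roi}")
    return 0

# total region to cover, and per-block read/write regions
total_roi = daisy.Roi((0, 0, 0), (1024, 1024, 1024))
write_roi = daisy.Roi((0, 0, 0), (256, 256, 256))
read_roi = write_roi.grow((16, 16, 16), (16, 16, 16))  # add context around each write ROI

task = daisy.Task(
    "example_task",
    total_roi=total_roi,
    read_roi=read_roi,
    write_roi=write_roi,
    process_function=process,
    read_write_conflict=False,
    num_workers=4,
    fit="shrink",
)

# returns True if all blocks ran successfully, mirroring the
# "Ran all blocks successfully!" check in the example scripts
done = daisy.run_blockwise([task])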