diff --git a/disbatch/disBatch.py b/disbatch/disBatch.py index 4dc70e7..9309d43 100644 --- a/disbatch/disBatch.py +++ b/disbatch/disBatch.py @@ -1002,112 +1002,154 @@ def poll(self): self.taskSource.done() -# Given a task source (generating task command lines), parse the lines and -# produce a TaskInfo generator. -def taskGenerator(tasks): - tsx = 0 # "line number" of current task - taskCounter = 0 # next taskId - peCounters = {'START': 0, 'STOP': 0} - perEngineAllowed = True - prefix = suffix = b'' - - def peEndListTasks(): - for when in ['START', 'STOP']: - yield TaskInfo( - peCounters[when], tsx, -1, b'#ENDLIST', '.per engine %s %d' % (when, peCounters[when]), kind='P' - ) +class TaskGenerator: + """ + Given a task source (generating task command lines), parse the lines and + yield TaskInfos. + + The generator also knows when it has no more user tasks, as indicated by + the `done` property. + """ + + def __init__(self, tasks): + self._tasks = tasks + self._done = False + self._generator = self._task_generator() + + # We look ahead by one task because we need to actually + # start constructing tasks to know if we are done. + self._buffer_next() + + def __iter__(self): + return self + + def __next__(self): + if self._exception: + raise self._exception + + val = self._val + self._buffer_next() + return val - OK = True - while OK: - tsx += 1 + def _buffer_next(self): try: - t = next(tasks) - except StopIteration: - # Signals there will be no more tasks. - break + self._val = next(self._generator) + self._exception = None + except StopIteration as e: + self._exception = e + self._val = None - # Split on newlines. - # - # This allows tasks submitted through kvs with or without newlines, - # including multiple tasks per item, or from files (always with single - # trailing newline). - # - # Note that multiple lines in the same item get the same streamIndex, - # but this shouldn't be a problem. (Alternatively could increment tsx - # inside this loop instead.) - for t in t.splitlines(): - if not t.startswith(b'#DISBATCH') and dbcomment.match(t): - # Comment or empty line, ignore - continue + @property + def done(self): + return self._done + + def _task_generator(self): + tsx = 0 # "line number" of current task + taskCounter = 0 # next taskId + peCounters = {'START': 0, 'STOP': 0} + perEngineAllowed = True + prefix = suffix = b'' + + def peEndListTasks(): + for when in ['START', 'STOP']: + yield TaskInfo( + peCounters[when], tsx, -1, b'#ENDLIST', '.per engine %s %d' % (when, peCounters[when]), kind='P' + ) - m = dbperengine.match(t) - if m: - if not perEngineAllowed: - # One could imagine doing some sort of per-engine - # reset with each barrier, but that would get - # messy pretty quickly. - logger.error('Per-engine tasks not permitted after normal tasks.') - OK = False - break - when, cmd = m.groups() - when = when.decode('ascii') - cmd = prefix + cmd + suffix - yield TaskInfo(peCounters[when], tsx, -1, cmd, '.per engine %s %d' % (when, peCounters[when]), kind='P') - peCounters[when] += 1 - continue + OK = True + while OK: + tsx += 1 + try: + t = next(self._tasks) + except StopIteration: + # Signals there will be no more tasks. + break - m = dbprefix.match(t) - if m: - prefix = m.group(1) - continue + # Split on newlines. + # + # This allows tasks submitted through kvs with or without newlines, + # including multiple tasks per item, or from files (always with single + # trailing newline). + # + # Note that multiple lines in the same item get the same streamIndex, + # but this shouldn't be a problem. (Alternatively could increment tsx + # inside this loop instead.) + for t in t.splitlines(): + if not t.startswith(b'#DISBATCH') and dbcomment.match(t): + # Comment or empty line, ignore + continue - m = dbsuffix.match(t) - if m: - suffix = m.group(1) - continue + m = dbperengine.match(t) + if m: + if not perEngineAllowed: + # One could imagine doing some sort of per-engine + # reset with each barrier, but that would get + # messy pretty quickly. + logger.error('Per-engine tasks not permitted after normal tasks.') + OK = False + break + when, cmd = m.groups() + when = when.decode('ascii') + cmd = prefix + cmd + suffix + yield TaskInfo( + peCounters[when], tsx, -1, cmd, '.per engine %s %d' % (when, peCounters[when]), kind='P' + ) + peCounters[when] += 1 + continue - if perEngineAllowed: - # Close out the per-engine task block. - perEngineAllowed = False - yield from peEndListTasks() - - m = dbrepeat.match(t) - if m: - repeats, rx, step = int(m.group('repeat')), 0, 1 - g = m.group('start') - if g: - rx = int(g) - g = m.group('step') - if g: - step = int(g) - logger.info('Processing repeat: %d %d %d', repeats, rx, step) - cmd = prefix + (m.group('command') or b'') + suffix - while repeats > 0: - yield TaskInfo(taskCounter, tsx, rx, cmd, '.task') - taskCounter += 1 - rx += step - repeats -= 1 - continue + m = dbprefix.match(t) + if m: + prefix = m.group(1) + continue - m = dbbarrier.match(t) - if m: - check, bKey = m.groups() - kind = 'C' if check else 'B' - yield TaskInfo(taskCounter, tsx, -1, t, '.barrier', kind=kind, bKey=bKey) - taskCounter += 1 - continue + m = dbsuffix.match(t) + if m: + suffix = m.group(1) + continue - if t.startswith(b'#DISBATCH '): - logger.error('Unknown #DISBATCH directive: %s', t) - else: - yield TaskInfo(taskCounter, tsx, -1, prefix + t + suffix, '.task') - taskCounter += 1 + if perEngineAllowed: + # Close out the per-engine task block. + perEngineAllowed = False + yield from peEndListTasks() + + m = dbrepeat.match(t) + if m: + repeats, rx, step = int(m.group('repeat')), 0, 1 + g = m.group('start') + if g: + rx = int(g) + g = m.group('step') + if g: + step = int(g) + logger.info('Processing repeat: %d %d %d', repeats, rx, step) + cmd = prefix + (m.group('command') or b'') + suffix + while repeats > 0: + yield TaskInfo(taskCounter, tsx, rx, cmd, '.task') + taskCounter += 1 + rx += step + repeats -= 1 + continue - if perEngineAllowed: - # Handle edge case of no tasks. - yield from peEndListTasks() + m = dbbarrier.match(t) + if m: + check, bKey = m.groups() + kind = 'C' if check else 'B' + yield TaskInfo(taskCounter, tsx, -1, t, '.barrier', kind=kind, bKey=bKey) + taskCounter += 1 + continue - logger.info('Processed %d tasks.', taskCounter) + if t.startswith(b'#DISBATCH '): + logger.error('Unknown #DISBATCH directive: %s', t) + else: + yield TaskInfo(taskCounter, tsx, -1, prefix + t + suffix, '.task') + taskCounter += 1 + + self._done = True + if perEngineAllowed: + # Handle edge case of no tasks. + yield from peEndListTasks() + + logger.info('Processed %d tasks.', taskCounter) def statusTaskFilter(tasks, status, retry=False, force=False): @@ -2368,7 +2410,10 @@ def shutdown(s=None, f=None): 'Task source name: ' + taskSource.name.decode('utf-8') ) # TODO: Think about the decoding a bit more? - tasks = taskGenerator(taskSource) + tasks = TaskGenerator(taskSource) + + if tasks.done: + print('No tasks to run.', file=sys.stderr) if args.resume_from: tasks = statusTaskFilter(tasks, parseStatusFiles(*args.resume_from), args.retry, args.force_resume) @@ -2394,33 +2439,37 @@ def shutdown(s=None, f=None): ) ) - if not args.startup_only: - # Is there a cleaner way to do this? - extraArgs = [] - argsD = args.__dict__ - for name in commonContextArgs: - v = argsD[name] - if v is None: - continue - aName = '--' + name.replace('_', '-') - if isinstance(v, bool): - if v: - extraArgs.append(aName) - elif isinstance(v, list): - for e in v: - extraArgs.extend([aName, str(e)]) - else: - extraArgs.extend([aName, str(v)]) + subContext = None + if not tasks.done: + # Don't start the context if there are no tasks to run, + # the kvs is about to shut down so the context won't be able to connect. + + if not args.startup_only: + # Is there a cleaner way to do this? + extraArgs = [] + argsD = args.__dict__ + for name in commonContextArgs: + v = argsD[name] + if v is None: + continue + aName = '--' + name.replace('_', '-') + if isinstance(v, bool): + if v: + extraArgs.append(aName) + elif isinstance(v, list): + for e in v: + extraArgs.extend([aName, str(e)]) + else: + extraArgs.extend([aName, str(v)]) - subContext = SUB.Popen( - [DbUtilPath] + extraArgs, - stdin=open(os.devnull), - stdout=open(uniqueId + '_context_wrap.out', 'w'), - close_fds=True, - ) - else: - print('Run this script to add compute contexts:\n ' + DbUtilPath) - subContext = None + subContext = SUB.Popen( + [DbUtilPath] + extraArgs, + stdin=open(os.devnull), + stdout=open(uniqueId + '_context_wrap.out', 'w'), + close_fds=True, + ) + else: + print('Run this script to add compute contexts:\n ' + DbUtilPath) driver = Driver(kvs, dbInfo, tasks, getattr(taskSource, 'resultkey', resultKey)) try: diff --git a/tests/test_slurm/run.sh b/tests/test_slurm/run.sh index c6a5487..edde8b0 100755 --- a/tests/test_slurm/run.sh +++ b/tests/test_slurm/run.sh @@ -12,6 +12,10 @@ salloc -n 2 disBatch Tasks [[ -f A.txt && -f B.txt && -f C.txt ]] success=$? +# Test empty task file +salloc -n 2 disBatch /dev/null +success=$((success + $?)) + cd - > /dev/null if [[ $success -eq 0 ]]; then diff --git a/tests/test_ssh/run.sh b/tests/test_ssh/run.sh index fab0f48..e767055 100755 --- a/tests/test_ssh/run.sh +++ b/tests/test_ssh/run.sh @@ -12,6 +12,10 @@ disBatch -s localhost:2 Tasks [[ -f A.txt && -f B.txt && -f C.txt ]] success=$? +# Test empty task file +disBatch -s localhost:2 /dev/null +success=$((success + $?)) + cd - > /dev/null if [[ $success -eq 0 ]]; then