Skip to content

Commit

Permalink
Gracefully handle empty task file
Browse files Browse the repository at this point in the history
  • Loading branch information
lgarrison committed Nov 20, 2024
1 parent 3bc8aef commit 3041917
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 122 deletions.
293 changes: 171 additions & 122 deletions disbatch/disBatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,112 +1002,154 @@ def poll(self):
self.taskSource.done()


# Given a task source (generating task command lines), parse the lines and
# produce a TaskInfo generator.
def taskGenerator(tasks):
tsx = 0 # "line number" of current task
taskCounter = 0 # next taskId
peCounters = {'START': 0, 'STOP': 0}
perEngineAllowed = True
prefix = suffix = b''

def peEndListTasks():
for when in ['START', 'STOP']:
yield TaskInfo(
peCounters[when], tsx, -1, b'#ENDLIST', '.per engine %s %d' % (when, peCounters[when]), kind='P'
)
class TaskGenerator:
"""
Given a task source (generating task command lines), parse the lines and
yield TaskInfos.
The generator also knows when it has no more user tasks, as indicated by
the `done` property.
"""

def __init__(self, tasks):
self.tasks = tasks
self._done = False
self._generator = self._task_generator()

# We look ahead by one task because we need to actually
# start constructing tasks to know if we are done.
self._buffer_next()

def __iter__(self):
return self

def __next__(self):
if self._exception:
raise self._exception

val = self._val
self._buffer_next()
return val

OK = True
while OK:
tsx += 1
def _buffer_next(self):
try:
t = next(tasks)
except StopIteration:
# Signals there will be no more tasks.
break
self._val = next(self._generator)
self._exception = None
except StopIteration as e:
self._exception = e
self._val = None

# Split on newlines.
#
# This allows tasks submitted through kvs with or without newlines,
# including multiple tasks per item, or from files (always with single
# trailing newline).
#
# Note that multiple lines in the same item get the same streamIndex,
# but this shouldn't be a problem. (Alternatively could increment tsx
# inside this loop instead.)
for t in t.splitlines():
if not t.startswith(b'#DISBATCH') and dbcomment.match(t):
# Comment or empty line, ignore
continue
@property
def done(self):
return self._done

def _task_generator(self):
tsx = 0 # "line number" of current task
taskCounter = 0 # next taskId
peCounters = {'START': 0, 'STOP': 0}
perEngineAllowed = True
prefix = suffix = b''

def peEndListTasks():
for when in ['START', 'STOP']:
yield TaskInfo(
peCounters[when], tsx, -1, b'#ENDLIST', '.per engine %s %d' % (when, peCounters[when]), kind='P'
)

m = dbperengine.match(t)
if m:
if not perEngineAllowed:
# One could imagine doing some sort of per-engine
# reset with each barrier, but that would get
# messy pretty quickly.
logger.error('Per-engine tasks not permitted after normal tasks.')
OK = False
break
when, cmd = m.groups()
when = when.decode('ascii')
cmd = prefix + cmd + suffix
yield TaskInfo(peCounters[when], tsx, -1, cmd, '.per engine %s %d' % (when, peCounters[when]), kind='P')
peCounters[when] += 1
continue
OK = True
while OK:
tsx += 1
try:
t = next(self.tasks)
except StopIteration:
# Signals there will be no more tasks.
break

m = dbprefix.match(t)
if m:
prefix = m.group(1)
continue
# Split on newlines.
#
# This allows tasks submitted through kvs with or without newlines,
# including multiple tasks per item, or from files (always with single
# trailing newline).
#
# Note that multiple lines in the same item get the same streamIndex,
# but this shouldn't be a problem. (Alternatively could increment tsx
# inside this loop instead.)
for t in t.splitlines():
if not t.startswith(b'#DISBATCH') and dbcomment.match(t):
# Comment or empty line, ignore
continue

m = dbsuffix.match(t)
if m:
suffix = m.group(1)
continue
m = dbperengine.match(t)
if m:
if not perEngineAllowed:
# One could imagine doing some sort of per-engine
# reset with each barrier, but that would get
# messy pretty quickly.
logger.error('Per-engine tasks not permitted after normal tasks.')
OK = False
break
when, cmd = m.groups()
when = when.decode('ascii')
cmd = prefix + cmd + suffix
yield TaskInfo(
peCounters[when], tsx, -1, cmd, '.per engine %s %d' % (when, peCounters[when]), kind='P'
)
peCounters[when] += 1
continue

if perEngineAllowed:
# Close out the per-engine task block.
perEngineAllowed = False
yield from peEndListTasks()

m = dbrepeat.match(t)
if m:
repeats, rx, step = int(m.group('repeat')), 0, 1
g = m.group('start')
if g:
rx = int(g)
g = m.group('step')
if g:
step = int(g)
logger.info('Processing repeat: %d %d %d', repeats, rx, step)
cmd = prefix + (m.group('command') or b'') + suffix
while repeats > 0:
yield TaskInfo(taskCounter, tsx, rx, cmd, '.task')
taskCounter += 1
rx += step
repeats -= 1
continue
m = dbprefix.match(t)
if m:
prefix = m.group(1)
continue

m = dbbarrier.match(t)
if m:
check, bKey = m.groups()
kind = 'C' if check else 'B'
yield TaskInfo(taskCounter, tsx, -1, t, '.barrier', kind=kind, bKey=bKey)
taskCounter += 1
continue
m = dbsuffix.match(t)
if m:
suffix = m.group(1)
continue

if t.startswith(b'#DISBATCH '):
logger.error('Unknown #DISBATCH directive: %s', t)
else:
yield TaskInfo(taskCounter, tsx, -1, prefix + t + suffix, '.task')
taskCounter += 1
if perEngineAllowed:
# Close out the per-engine task block.
perEngineAllowed = False
yield from peEndListTasks()

m = dbrepeat.match(t)
if m:
repeats, rx, step = int(m.group('repeat')), 0, 1
g = m.group('start')
if g:
rx = int(g)
g = m.group('step')
if g:
step = int(g)
logger.info('Processing repeat: %d %d %d', repeats, rx, step)
cmd = prefix + (m.group('command') or b'') + suffix
while repeats > 0:
yield TaskInfo(taskCounter, tsx, rx, cmd, '.task')
taskCounter += 1
rx += step
repeats -= 1
continue

if perEngineAllowed:
# Handle edge case of no tasks.
yield from peEndListTasks()
m = dbbarrier.match(t)
if m:
check, bKey = m.groups()
kind = 'C' if check else 'B'
yield TaskInfo(taskCounter, tsx, -1, t, '.barrier', kind=kind, bKey=bKey)
taskCounter += 1
continue

logger.info('Processed %d tasks.', taskCounter)
if t.startswith(b'#DISBATCH '):
logger.error('Unknown #DISBATCH directive: %s', t)
else:
yield TaskInfo(taskCounter, tsx, -1, prefix + t + suffix, '.task')
taskCounter += 1

self._done = True
if perEngineAllowed:
# Handle edge case of no tasks.
yield from peEndListTasks()

logger.info('Processed %d tasks.', taskCounter)


def statusTaskFilter(tasks, status, retry=False, force=False):
Expand Down Expand Up @@ -2368,7 +2410,10 @@ def shutdown(s=None, f=None):
'Task source name: ' + taskSource.name.decode('utf-8')
) # TODO: Think about the decoding a bit more?

tasks = taskGenerator(taskSource)
tasks = TaskGenerator(taskSource)

if tasks.done:
print('No tasks to run.', file=sys.stderr)

if args.resume_from:
tasks = statusTaskFilter(tasks, parseStatusFiles(*args.resume_from), args.retry, args.force_resume)
Expand All @@ -2394,33 +2439,37 @@ def shutdown(s=None, f=None):
)
)

if not args.startup_only:
# Is there a cleaner way to do this?
extraArgs = []
argsD = args.__dict__
for name in commonContextArgs:
v = argsD[name]
if v is None:
continue
aName = '--' + name.replace('_', '-')
if isinstance(v, bool):
if v:
extraArgs.append(aName)
elif isinstance(v, list):
for e in v:
extraArgs.extend([aName, str(e)])
else:
extraArgs.extend([aName, str(v)])
subContext = None
if not tasks.done:
# Don't start the context if there are no tasks to run,
# the kvs is about to shut down so the context won't be able to connect.

if not args.startup_only:
# Is there a cleaner way to do this?
extraArgs = []
argsD = args.__dict__
for name in commonContextArgs:
v = argsD[name]
if v is None:
continue
aName = '--' + name.replace('_', '-')
if isinstance(v, bool):
if v:
extraArgs.append(aName)
elif isinstance(v, list):
for e in v:
extraArgs.extend([aName, str(e)])
else:
extraArgs.extend([aName, str(v)])

subContext = SUB.Popen(
[DbUtilPath] + extraArgs,
stdin=open(os.devnull),
stdout=open(uniqueId + '_context_wrap.out', 'w'),
close_fds=True,
)
else:
print('Run this script to add compute contexts:\n ' + DbUtilPath)
subContext = None
subContext = SUB.Popen(
[DbUtilPath] + extraArgs,
stdin=open(os.devnull),
stdout=open(uniqueId + '_context_wrap.out', 'w'),
close_fds=True,
)
else:
print('Run this script to add compute contexts:\n ' + DbUtilPath)

driver = Driver(kvs, dbInfo, tasks, getattr(taskSource, 'resultkey', resultKey))
try:
Expand Down
4 changes: 4 additions & 0 deletions tests/test_slurm/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ salloc -n 2 disBatch Tasks
[[ -f A.txt && -f B.txt && -f C.txt ]]
success=$?

# Test empty task file
salloc -n 2 disBatch /dev/null
success=$((success + $?))

cd - > /dev/null

if [[ $success -eq 0 ]]; then
Expand Down
4 changes: 4 additions & 0 deletions tests/test_ssh/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ disBatch -s localhost:2 Tasks
[[ -f A.txt && -f B.txt && -f C.txt ]]
success=$?

# Test empty task file
disBatch -s localhost:2 /dev/null
success=$((success + $?))

cd - > /dev/null

if [[ $success -eq 0 ]]; then
Expand Down

0 comments on commit 3041917

Please sign in to comment.