Skip to content

Commit

Permalink
Use temporary scripts to launch Slurm array jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
takluyver committed Sep 13, 2024
1 parent 38bcb05 commit adf5e04
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 42 deletions.
8 changes: 2 additions & 6 deletions damnit/backend/extract_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,7 @@ def main(argv=None):
# Hide some logging from Kafka to make things more readable
logging.getLogger('kafka').setLevel(logging.WARNING)

if (run := args.run) == -1:
log.debug("Getting run number from Slurm array task ID")
run = int(os.environ['SLURM_ARRAY_TASK_ID'])

print(f"\n----- Processing r{run} (p{args.proposal}) -----", file=sys.stderr)
print(f"\n----- Processing r{args.run} (p{args.proposal}) -----", file=sys.stderr)
log.info(f"run_data={args.run_data}, match={args.match}")
if args.mock:
log.info("Using mock run object for testing")
Expand All @@ -281,7 +277,7 @@ def main(argv=None):
if args.update_vars:
extr.update_db_vars()

extr.extract_and_ingest(args.proposal, run,
extr.extract_and_ingest(args.proposal, args.run,
cluster=args.cluster_job,
run_data=RunData(args.run_data),
match=args.match,
Expand Down
85 changes: 49 additions & 36 deletions damnit/backend/extraction_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import sys
from contextlib import contextmanager
from ctypes import CDLL
from dataclasses import dataclass, replace
from itertools import groupby
from dataclasses import dataclass
from pathlib import Path
from secrets import token_hex
from threading import Thread

from extra_data.read_machinery import find_proposal
Expand All @@ -39,8 +39,6 @@ def default_slurm_partition():


def process_log_path(run, proposal, ctx_dir=Path('.'), create=True):
if run == -1:
run = "%a" # Slurm array task ID
p = ctx_dir.absolute() / 'process_logs' / f"r{run}-p{proposal}.out"
if create:
p.parent.mkdir(exist_ok=True)
Expand Down Expand Up @@ -77,6 +75,16 @@ def proposal_runs(proposal):
raw_dir = Path(find_proposal(proposal_name)) / "raw"
return set(int(p.stem[1:]) for p in raw_dir.glob("*"))

def batches(l, n):
    """Yield successive slices of *l*, each holding at most *n* items.

    Stops as soon as a slice comes back empty, so an empty input
    yields nothing.
    """
    pos = 0
    while chunk := l[pos:pos + n]:
        yield chunk
        pos += n


@dataclass
class ExtractionRequest:
Expand Down Expand Up @@ -164,50 +172,55 @@ def sbatch_cmd(self, req: ExtractionRequest):

def submit_multi(self, reqs: list[ExtractionRequest], limit_running=30):
"""Submit multiple requests using Slurm job arrays.
Requests with only the run number different are grouped together.
limit_running controls how many can run simultaneously *within* each
group. Normally most/all runs will be in one group, so this will be
close to the overall limit.
"""
out = []
# run -1 tells extract_data to take the run # from the Slurm array task
for generic_req, req_group in groupby(reqs, key=lambda r: replace(r, run=-1)):
cmd = self.sbatch_cmd(generic_req) # -1 -> use Slurm array task id
runs = [req.run for req in req_group]
array_spec = self._abbrev_array_nums(runs) + f'%{limit_running}'
cmd += ['--array', array_spec]

assert len({r.cluster for r in reqs}) <= 1 # Don't mix cluster/non-cluster

# Array jobs are limited to 1001 in Slurm config (MaxArraySize)
for req_group in batches(reqs, 1000):
grpid = token_hex(8) # random unique string
scripts_dir = self.context_dir / '.tmp'
scripts_dir.mkdir(exist_ok=True)
if scripts_dir.stat().st_uid == os.getuid():
scripts_dir.chmod(0o777)

for i, req in enumerate(req_group):
script_file = scripts_dir / f'launch-{grpid}-{i}.sh'
log_path = process_log_path(req.run, req.proposal, self.context_dir)
script_file.write_text(
'rm "$0"\n' # Script cleans itself up
f'{shlex.join(req.python_cmd())} >>"{log_path}" 2>&1'
)
script_file.chmod(0o777)

script_expr = f".tmp/launch-{grpid}-$SLURM_ARRAY_TASK_ID.sh"
cmd = self.sbatch_array_cmd(script_expr, req_group, limit_running)
res = subprocess.run(
cmd, stdout=subprocess.PIPE, text=True, check=True, cwd=self.context_dir,
)
job_id, _, cluster = res.stdout.partition(';')
job_id = job_id.strip()
cluster = cluster.strip() or 'maxwell'
log.info("Launched Slurm (%s) job %s with array %s (%d runs) to run context file",
cluster, job_id, array_spec, len(runs))
log.info("Launched Slurm (%s) job array %s (%d runs) to run context file",
cluster, job_id, len(req_group))
out.append((job_id, cluster))

return out

@staticmethod
def _abbrev_array_nums(nums: list[int]) -> str:
range_starts, range_ends = [nums[0]], []
current_range_end = nums[0]
for r in nums[1:]:
if r > current_range_end + 1:
range_ends.append(current_range_end)
range_starts.append(r)
current_range_end = r
range_ends.append(current_range_end)

s_pieces = []
for start, end in zip(range_starts, range_ends):
if start == end:
s_pieces.append(str(start))
else:
s_pieces.append(f"{start}-{end}")

return ",".join(s_pieces)
def sbatch_array_cmd(self, script_expr, reqs, limit_running=30):
    """Make the sbatch command for an array job"""
    # All requests in a group share proposal & cluster; read them from
    # the first one (this should never be called with an empty list).
    req = reqs[0]
    cmd = ['sbatch', '--parsable']
    cmd.extend(self._resource_opts(req.cluster))
    # Slurm doesn't know the run number, so we redirect inside the job
    cmd.extend(['-o', '/dev/null'])
    cmd.append('--open-mode=append')
    cmd.extend(['--job-name', f"p{req.proposal}-damnit"])
    # One array task per request, at most limit_running at a time
    cmd.extend(['--array', f"0-{len(reqs)-1}%{limit_running}"])
    cmd.extend(['--wrap', f'exec {script_expr}'])
    return cmd

def execute_in_slurm(self, req: ExtractionRequest):
"""Run an extraction job in srun with live output"""
Expand Down

0 comments on commit adf5e04

Please sign in to comment.