diff --git a/rl/cli/main.py b/rl/cli/main.py index b458044..20ed367 100644 --- a/rl/cli/main.py +++ b/rl/cli/main.py @@ -275,7 +275,7 @@ def create_batch_job(sbatch_args, name, job_time): ] subprocess.run(sbatch_args, check=True) - job_node, job_id = None, None + curr_job: JobInfo with rich.progress.Progress(transient=True) as progress: # noinspection PyTypeChecker task = progress.add_task( @@ -289,29 +289,22 @@ def create_batch_job(sbatch_args, name, job_time): "Job not found when checking status with squeue; what happened?" ) if curr_job.state == JobState.RUNNING: - job_node, job_id = curr_job.nodes, curr_job.job_id progress.update(task, completed=1) break - if "[" in job_node or "," in job_node: + if len(curr_job.nodes) > 1: rich.print( - f"[green]Job {job_id} started on nodes {job_node}. Pick one to ssh into.[/green]" + f"[green]Job {curr_job.job_id} started on nodes {', '.join(curr_job.nodes)}. SSHing into first node...[/green]" ) - # Get the user's choice of node - job_node = click.prompt("Node", type=str) else: rich.print( - f"[green]Job {job_id} started on node {job_node}. SSHing into node...[/green]" + f"[green]Job {curr_job.job_id} started on node {curr_job.nodes[0]}. SSHing into node...[/green]" ) - ssh_args = [ - SSH_PATH, - job_node, - ] - subprocess.run(ssh_args) + _ssh_within_sherlock(curr_job.nodes[0]) if click.confirm("Left the job, do you want to cancel it?"): - subprocess.run(["scancel", job_id]) + subprocess.run(["scancel", curr_job.job_id]) rich.print("[red]Job ended[/red]") else: - rich.print(f"[green]Job {job_id} will continue running[/green]") + rich.print(f"[green]Job {curr_job.job_id} will continue running[/green]") @_must_run_on_sherlock @@ -457,11 +450,9 @@ def _ssh_within_sherlock(node: str): if not node: node = _select_node() rich.print(f"[green]SSHing into {node}[/green]") - # When sshing in, we want to try to tmux attach and if that fails, just open a shell + # When SSHing in, we want to try to tmux attach and if that fails, just open a shell run_command = "tmux attach || fish || bash" - ssh = pexpect.spawn(f"ssh {node} -t '{run_command}'") - threading.Thread(target=_resize_ssh, args=(ssh,)).start() - ssh.interact() + subprocess.run([SSH_PATH, node, "-t", run_command]) def _select_node() -> str: