Skip to content

Commit

Permalink
scheduler: better and more tests (#57)
Browse files Browse the repository at this point in the history
* scheduler: better tests

Co-authored-by: Alexander Goscinski <[email protected]>

---------

Co-authored-by: Alexander Goscinski <[email protected]>
  • Loading branch information
khsrali and agoscinski authored Aug 20, 2024
1 parent a147f87 commit 151e5a6
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 42 deletions.
3 changes: 2 additions & 1 deletion .firecrest-demo-config.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
"temp_directory": "",
"small_file_size_mb": 5.0,
"workdir": "",
"api_version": "1.16.0"
"api_version": "1.16.0",
"builder_metadata_options_custom_scheduler_commands": []
}
9 changes: 6 additions & 3 deletions aiida_firecrest/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,17 @@ def get_jobs(
try:
for page_iter in itertools.count():
results += transport._client.poll_active(
transport._machine, jobs, page_number=page_iter
transport._machine,
jobs,
page_number=page_iter,
page_size=self._DEFAULT_PAGE_SIZE,
)
if len(results) < self._DEFAULT_PAGE_SIZE * (page_iter + 1):
break
except FirecrestException as exc:
# firecrest returns error if the job is completed
# TODO: check what type of error is returned and handle it properly
if "Invalid job id specified" not in str(exc):
if "Invalid job id" not in str(exc):
# firecrest returns error if the job is completed, while aiida expect a silent return
raise SchedulerError(str(exc)) from exc
job_list = []
for raw_result in results:
Expand Down
8 changes: 6 additions & 2 deletions aiida_firecrest/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ def __init__(
# aiida-core/src/aiida/orm/utils/remote:clean_remote()
self._is_open = True

self.checksum_check = False

def __str__(self) -> str:
"""Return the name of the plugin."""
return self.__class__.__name__
Expand Down Expand Up @@ -723,7 +725,8 @@ def getfile(
down_obj = self._client.external_download(self._machine, str(remote))
down_obj.finish_download(local)

self._validate_checksum(local, remote)
if self.checksum_check:
self._validate_checksum(local, remote)

def _validate_checksum(
self, localpath: str | Path, remotepath: str | FcPath
Expand Down Expand Up @@ -965,7 +968,8 @@ def putfile(
)
up_obj.finish_upload()

self._validate_checksum(localpath, str(remote))
if self.checksum_check:
self._validate_checksum(localpath, str(remote))

def payoff(self, path: str | FcPath | Path) -> bool:
"""
Expand Down
62 changes: 49 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, field
import hashlib
import itertools
import json
import os
from pathlib import Path
import shutil
import stat
from typing import Any, Callable
from typing import Any, Callable, ClassVar
from unittest.mock import MagicMock
from urllib.parse import urlparse

from aiida import orm
import firecrest
import firecrest.path
import pytest
import requests


class Values:
_DEFAULT_PAGE_SIZE: int = 25
class Slurm:
"""Save the submitted job ids for testing purposes."""

all_jobs: ClassVar[list] = []


@pytest.fixture
Expand Down Expand Up @@ -75,7 +77,12 @@ def submit(
raise DeprecationWarning("local_file is not supported")

if script_remote_path and not Path(script_remote_path).exists():
raise FileNotFoundError(f"File {script_remote_path} does not exist")
# Firecrest raises FirecrestException instead of FileNotFoundError
mock_response = MagicMock()
mock_response.status_code = 999 # I don't really know
mock_response.json.return_value = {"error": "Mock error message"}
raise firecrest.FirecrestException(mock_response)

job_id = next(self.job_id_generator)

# Filter out lines starting with '#SBATCH'
Expand All @@ -98,9 +105,18 @@ def submit(
os.chdir(Path(script_remote_path).parent)
os.system(command)

Slurm.all_jobs.append(job_id)

return {"jobid": job_id}

def poll_active(self, machine: str, jobs: list[str], page_number: int = 0):
def cancel(self, machine: str, job_id: str):
job_id = int(job_id)
if job_id in Slurm.all_jobs:
Slurm.all_jobs.remove(job_id)

def poll_active(
self, machine: str, jobs: list[str], page_number: int = 0, page_size: int = 25
):
response = []
# 12 satets are defined in firecrest
states = [
Expand All @@ -118,6 +134,11 @@ def poll_active(self, machine: str, jobs: list[str], page_number: int = 0):
"COMPLETING",
]
for i in range(len(jobs)):
if int(jobs[i]) not in Slurm.all_jobs:
mock_response = MagicMock()
mock_response.status_code = 999 # I don't really know
mock_response.json.return_value = {"error": "Invalid job id"}
raise firecrest.FirecrestException([mock_response])
response.append(
{
"job_data_err": "",
Expand All @@ -139,11 +160,7 @@ def poll_active(self, machine: str, jobs: list[str], page_number: int = 0):
}
)

return response[
page_number
* Values._DEFAULT_PAGE_SIZE : (page_number + 1)
* Values._DEFAULT_PAGE_SIZE
]
return response[page_number * page_size : (page_number + 1) * page_size]

def whoami(self, machine: str):
assert machine == "MACHINE_NAME"
Expand Down Expand Up @@ -373,7 +390,22 @@ def __init__(self, *args, **kwargs):

@dataclass
class ComputerFirecrestConfig:
"""Configuration of a computer using FirecREST as transport plugin."""
"""Configuration of a computer using FirecREST as transport plugin.
:param url: The URL of the FirecREST server.
:param token_uri: The URI to receive tokens.
:param client_id: The client ID for the client credentials.
:param client_secret: The client secret for the client credentials.
:param compute_resource: The name of the compute resource. This is the name of the machine.
:param temp_directory: A temporary directory on the machine for transient zip files.
:param workdir: The aiida working directory on the machine.
:param api_version: The version of the FirecREST API.
:param builder_metadata_options_custom_scheduler_commands: A list of custom
scheduler commands when submitting a job, for example
["#SBATCH --account=mr32",
"#SBATCH --constraint=mc",
"#SBATCH --mem=10K"].
:param small_file_size_mb: The maximum file size for direct upload & download."""

url: str
token_uri: str
Expand All @@ -384,6 +416,9 @@ class ComputerFirecrestConfig:
workdir: str
api_version: str
small_file_size_mb: float = 1.0
builder_metadata_options_custom_scheduler_commands: list[str] = field(
default_factory=list
)


class RequestTelemetry:
Expand Down Expand Up @@ -514,4 +549,5 @@ def firecrest_config(
small_file_size_mb=1.0,
temp_directory=str(_temp_directory),
api_version="2",
builder_metadata_options_custom_scheduler_commands=[],
)
7 changes: 6 additions & 1 deletion tests/test_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ def _no_retries():
manage.get_config().set_option(MAX_ATTEMPTS_OPTION, max_attempts)


@pytest.mark.timeout(180)
@pytest.mark.usefixtures("aiida_profile_clean", "no_retries")
def test_calculation_basic(firecrest_computer: orm.Computer):
def test_calculation_basic(firecrest_computer: orm.Computer, firecrest_config):
"""Test running a simple `arithmetic.add` calculation."""
code = orm.InstalledCode(
label="test_code",
Expand All @@ -35,6 +36,10 @@ def test_calculation_basic(firecrest_computer: orm.Computer):
builder = code.get_builder()
builder.x = orm.Int(1)
builder.y = orm.Int(2)
custom_scheduler_commands = "\n".join(
firecrest_config.builder_metadata_options_custom_scheduler_commands
)
builder.metadata.options.custom_scheduler_commands = custom_scheduler_commands

_, node = engine.run_get_node(builder)
assert node.is_finished_ok
Expand Down
122 changes: 100 additions & 22 deletions tests/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,125 @@
from pathlib import Path
import textwrap
from time import sleep

from aiida import orm
from aiida.schedulers import SchedulerError
from aiida.schedulers.datastructures import CodeRunMode, JobTemplate
import pytest

from aiida_firecrest.scheduler import FirecrestScheduler
from conftest import Values


@pytest.mark.usefixtures("aiida_profile_clean")
def test_submit_job(firecrest_computer: orm.Computer, tmp_path: Path):
def test_submit_job(firecrest_computer: orm.Computer, firecrest_config, tmpdir: Path):
"""Test submitting a job to the scheduler.
Note: this test relies on a functional transport.put() method."""

transport = firecrest_computer.get_transport()
scheduler = FirecrestScheduler()
scheduler.set_transport(transport)

with pytest.raises(FileNotFoundError):
scheduler.submit_job(transport.getcwd(), "unknown.sh")
# raise error if file not found
with pytest.raises(SchedulerError):
scheduler.submit_job(firecrest_config.workdir, "unknown.sh")

custom_scheduler_commands = "\n ".join(
firecrest_config.builder_metadata_options_custom_scheduler_commands
)

shell_script = f"""
#!/bin/bash
#SBATCH --no-requeue
#SBATCH --job-name="aiida-1928"
#SBATCH --get-user-env
#SBATCH --output=_scheduler-stdout.txt
#SBATCH --error=_scheduler-stderr.txt
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
{custom_scheduler_commands}
echo 'hello world'
"""

dedented_script = textwrap.dedent(shell_script).strip()
Path(tmpdir / "job.sh").write_text(dedented_script)
remote_ = transport._cwd.joinpath(firecrest_config.workdir, "job.sh")
transport.put(tmpdir / "job.sh", remote_)

_script = Path(tmp_path / "job.sh")
_script.write_text("#!/bin/bash\n\necho 'hello world'")
job_id = scheduler.submit_job(firecrest_config.workdir, "job.sh")

job_id = scheduler.submit_job(transport.getcwd(), _script)
# this is how aiida expects the job_id to be returned
assert isinstance(job_id, str)


@pytest.mark.timeout(180)
@pytest.mark.usefixtures("aiida_profile_clean")
def test_get_jobs(firecrest_computer: orm.Computer):
def test_get_and_kill_jobs(
firecrest_computer: orm.Computer, firecrest_config, tmpdir: Path
):
"""Test getting and killing jobs from the scheduler.
We test the two together for performance reasons, as this test might run against
a real server and we don't want to leave parasitic jobs behind.
also less billing for the user.
Note: this test relies on a functional transport.put() method.
"""
import time

transport = firecrest_computer.get_transport()
scheduler = FirecrestScheduler()
scheduler.set_transport(transport)

# test pagaination
scheduler._DEFAULT_PAGE_SIZE = 2
Values._DEFAULT_PAGE_SIZE = 2
# verify that no error is raised in the case of an invalid job id 000
scheduler.get_jobs(["000"])

custom_scheduler_commands = "\n ".join(
firecrest_config.builder_metadata_options_custom_scheduler_commands
)
shell_script = f"""
#!/bin/bash
#SBATCH --no-requeue
#SBATCH --job-name="aiida-1929"
#SBATCH --get-user-env
#SBATCH --output=_scheduler-stdout.txt
#SBATCH --error=_scheduler-stderr.txt
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
{custom_scheduler_commands}
sleep 180
"""

joblist = ["111", "222", "333", "444", "555"]
joblist = []
dedented_script = textwrap.dedent(shell_script).strip()
Path(tmpdir / "job.sh").write_text(dedented_script)
remote_ = transport._cwd.joinpath(firecrest_config.workdir, "job.sh")
transport.put(tmpdir / "job.sh", remote_)

for _ in range(5):
joblist.append(scheduler.submit_job(firecrest_config.workdir, "job.sh"))

# test pagaination is working
scheduler._DEFAULT_PAGE_SIZE = 2
result = scheduler.get_jobs(joblist)
assert len(result) == 5
for i in range(5):
assert result[i].job_id == str(joblist[i])
assert result[i].job_id in joblist
# TODO: one could check states as well

# test kill jobs
for jobid in joblist:
scheduler.kill_job(jobid)

# sometimes it takes time for the server to actually kill the jobs
timeout_kill = 5 # seconds
start_time = time.time()
while time.time() - start_time < timeout_kill:
result = scheduler.get_jobs(joblist)
if not len(result):
break
sleep(0.5)

assert not len(result)


def test_write_script_full():
# to avoid false positive (overwriting on existing file),
Expand All @@ -68,9 +146,9 @@ def test_write_script_full():
#SBATCH --mem=1
test_command
"""
expectaion_flat = "\n".join(line.strip() for line in expectaion.splitlines()).strip(
"\n"
)
expectation_flat = "\n".join(
line.strip() for line in expectaion.splitlines()
).strip("\n")
scheduler = FirecrestScheduler()
template = JobTemplate(
{
Expand Down Expand Up @@ -98,7 +176,7 @@ def test_write_script_full():
}
)
try:
assert scheduler.get_submit_script(template).rstrip() == expectaion_flat
assert scheduler.get_submit_script(template).rstrip() == expectation_flat
except AssertionError:
print(scheduler.get_submit_script(template).rstrip())
print(expectaion)
Expand All @@ -116,9 +194,9 @@ def test_write_script_minimal():
#SBATCH --ntasks-per-node=1
"""

expectaion_flat = "\n".join(line.strip() for line in expectaion.splitlines()).strip(
"\n"
)
expectation_flat = "\n".join(
line.strip() for line in expectaion.splitlines()
).strip("\n")
scheduler = FirecrestScheduler()
template = JobTemplate(
{
Expand All @@ -131,7 +209,7 @@ def test_write_script_minimal():
)

try:
assert scheduler.get_submit_script(template).rstrip() == expectaion_flat
assert scheduler.get_submit_script(template).rstrip() == expectation_flat
except AssertionError:
print(scheduler.get_submit_script(template).rstrip())
print(expectaion)
Expand Down
Loading

0 comments on commit 151e5a6

Please sign in to comment.