diff --git a/.github/workflows/tag_and_publish.yml b/.github/workflows/tag_and_publish.yml index 90419da..1bc02e0 100644 --- a/.github/workflows/tag_and_publish.yml +++ b/.github/workflows/tag_and_publish.yml @@ -3,9 +3,7 @@ on: push: branches: - main -# Remove line 61 to enable automated semantic version bumps. -# Change line 67 from "if: false" to "if: true" to enable PyPI publishing. -# Requires that svc-aindscicomp be added as an admin to repo. + jobs: update_badges: runs-on: ubuntu-latest @@ -16,10 +14,10 @@ jobs: ref: ${{ env.DEFAULT_BRANCH }} fetch-depth: 0 token: ${{ secrets.SERVICE_TOKEN }} - - name: Set up Python 3.8 + - name: Set up Python 3.10 uses: actions/setup-python@v3 with: - python-version: 3.8 + python-version: '3.10' - name: Install dependencies run: | python -m pip install -e .[dev] --no-cache-dir @@ -62,28 +60,31 @@ jobs: add: '["README.md"]' tag: needs: update_badges - if: ${{github.event.repository.name == 'aind-library-template'}} uses: AllenNeuralDynamics/aind-github-actions/.github/workflows/tag.yml@main secrets: SERVICE_TOKEN: ${{ secrets.SERVICE_TOKEN }} publish: - needs: tag - if: false runs-on: ubuntu-latest + needs: tag steps: - uses: actions/checkout@v3 - name: Pull latest changes run: git pull origin main - - name: Set up Python 3.8 - uses: actions/setup-python@v2 + - name: Set up Docker Buildx + id: buildx + uses: docker/setup-buildx-action@v2 + - name: Login to Github Packages + uses: docker/login-action@v2 with: - python-version: 3.8 - - name: Install dependencies - run: | - pip install --upgrade setuptools wheel twine build - python -m build - twine check dist/* - - name: Publish on PyPI - uses: pypa/gh-action-pypi-publish@release/v1 + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Build image and push to GitHub Container Registry + uses: docker/build-push-action@v3 with: - password: ${{ secrets.AIND_PYPI_TOKEN }} + # relative path to the place where source code with Dockerfile is located + context: . + push: true + tags: | + ghcr.io/allenneuraldynamics/aind-airflow-jobs:${{ needs.tag.outputs.new_version }} + ghcr.io/allenneuraldynamics/aind-airflow-jobs:latest diff --git a/.github/workflows/test_and_lint.yml b/.github/workflows/test_and_lint.yml index c8d832d..33ef8a9 100644 --- a/.github/workflows/test_and_lint.yml +++ b/.github/workflows/test_and_lint.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [ '3.8', '3.9', '3.10' ] + python-version: [ '3.10', '3.11' ] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..0211c50 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,8 @@ +from python:3.10-slim + +WORKDIR /app +ADD src ./src +ADD pyproject.toml . +ADD setup.py . + +RUN pip install . --no-cache-dir diff --git a/doc_template/source/conf.py b/doc_template/source/conf.py index 61c644c..1a120a3 100644 --- a/doc_template/source/conf.py +++ b/doc_template/source/conf.py @@ -1,12 +1,15 @@ """Configuration file for the Sphinx documentation builder.""" + # # For the full list of built-in configuration values, see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html +from datetime import date + # -- Path Setup -------------------------------------------------------------- -from os.path import dirname, abspath +from os.path import abspath, dirname from pathlib import Path -from datetime import date + from aind_airflow_jobs import __version__ as package_version INSTITUTE_NAME = "Allen Institute for Neural Dynamics" diff --git a/pyproject.toml b/pyproject.toml index e12d5a9..719cb53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "aind-airflow-jobs" description = "Generated from aind-library-template" license = {text = "MIT"} -requires-python = ">=3.7" +requires-python = ">=3.10" authors = [ {name = "Allen Institute for Neural Dynamics"} ] @@ -17,6 +17,9 @@ readme = "README.md" dynamic = ["version"] dependencies = [ + "aind-slurm-rest", + "pydantic-settings", + "pydantic" ] [project.optional-dependencies] @@ -38,7 +41,7 @@ version = {attr = "aind_airflow_jobs.__version__"} [tool.black] line-length = 79 -target_version = ['py36'] +target_version = ['py310'] exclude = ''' ( diff --git a/src/aind_airflow_jobs/__init__.py b/src/aind_airflow_jobs/__init__.py index d0a8547..e2bb66c 100644 --- a/src/aind_airflow_jobs/__init__.py +++ b/src/aind_airflow_jobs/__init__.py @@ -1,2 +1,3 @@ -"""Init package""" +"""Package to manage airflow jobs""" + __version__ = "0.0.0" diff --git a/src/aind_airflow_jobs/submit_slurm_job.py b/src/aind_airflow_jobs/submit_slurm_job.py new file mode 100644 index 0000000..b24c181 --- /dev/null +++ b/src/aind_airflow_jobs/submit_slurm_job.py @@ -0,0 +1,397 @@ +"""Module to submit and monitor slurm jobs via the slurm rest api""" + +import binascii +import json +import logging +import sys +from argparse import ArgumentParser +from datetime import datetime +from enum import Enum +from pathlib import Path +from time import sleep +from typing import Dict, List +from uuid import uuid4 + +from aind_slurm_rest import ApiClient as Client +from aind_slurm_rest import Configuration as Config +from aind_slurm_rest import V0036JobSubmissionResponse +from aind_slurm_rest.api.slurm_api import SlurmApi +from aind_slurm_rest.models.v0036_job_properties import V0036JobProperties +from aind_slurm_rest.models.v0036_job_submission import V0036JobSubmission +from pydantic import Field, SecretStr +from pydantic_settings import BaseSettings, SettingsConfigDict + +logging.basicConfig(level="INFO") + + +class SlurmClientSettings(BaseSettings): + """Settings required to build slurm api client""" + + model_config = SettingsConfigDict(env_prefix="SLURM_CLIENT_") + host: str + username: str + password: SecretStr + access_token: SecretStr + + def create_api_client(self) -> SlurmApi: + """Create an api client using settings""" + config = Config( + host=self.host, + password=self.password.get_secret_value(), + username=self.username, + access_token=self.access_token.get_secret_value(), + ) + slurm = SlurmApi(Client(config)) + slurm.api_client.set_default_header( + header_name="X-SLURM-USER-NAME", + header_value=self.username, + ) + slurm.api_client.set_default_header( + header_name="X-SLURM-USER-PASSWORD", + header_value=self.password.get_secret_value(), + ) + slurm.api_client.set_default_header( + header_name="X-SLURM-USER-TOKEN", + header_value=self.access_token.get_secret_value(), + ) + return slurm + + +class DefaultSlurmSettings(BaseSettings): + """Configurations with default values or expected to be pulled from env + vars.""" + + model_config = SettingsConfigDict(env_prefix="SLURM_") + log_path: str + partition: str + name: str = Field( + default_factory=lambda: ( + f"job" + f"_{str(int(datetime.utcnow().timestamp()))}" + f"_{str(uuid4())[0:5]}" + ) + ) + qos: str = Field(default="dev") + environment: Dict[str, str] = Field( + default={ + "PATH": "/bin:/usr/bin/:/usr/local/bin/", + "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib", + } + ) + memory_per_node: int = Field(default=50000) + tasks: int = Field(default=1) + minimum_cpus_per_node: int = Field(default=1) + nodes: List[int] = Field(default=[1, 1]) + time_limit: int = Field(default=360) + + +class JobState(str, Enum): + """The possible job_state values in the V0036JobsResponse class. The enums + don't appear to be importable from the aind-slurm-rest api.""" + + # Job terminated due to launch failure, typically due to a hardware failure + # (e.g. unable to boot the node or block and the job can not be + # requeued). + BF = "BOOT_FAIL" + + # Job was explicitly cancelled by the user or system administrator. The job + # may or may not have been initiated. + CA = "CANCELLED" + + # Job has terminated all processes on all nodes with an exit code of zero. + CD = "COMPLETED" + + # Job has been allocated resources, but are waiting for them to become + # ready for use (e.g. booting). + CF = "CONFIGURING" + + # Job is in the process of completing. Some processes on some nodes may + # still be active. + CG = "COMPLETING" + + # Job terminated on deadline. + DL = "DEADLINE" + + # Job terminated with non-zero exit code or other failure condition. + F = "FAILED" + + # Job terminated due to failure of one or more allocated nodes. + NF = "NODE_FAIL" + + # Job experienced out of memory error. + OOM = "OUT_OF_MEMORY" + + # Job is awaiting resource allocation. + PD = "PENDING" + + # Job terminated due to preemption. + PR = "PREEMPTED" + + # Job currently has an allocation. + R = "RUNNING" + + # Job is being held after requested reservation was deleted. + RD = "RESV_DEL_HOLD" + + # Job is being requeued by a federation. + RF = "REQUEUE_FED" + + # Held job is being requeued. + RH = "REQUEUE_HOLD" + + # Completing job is being requeued. + RQ = "REQUEUED" + + # Job is about to change size. + RS = "RESIZING" + + # Sibling was removed from cluster due to other cluster starting the job. + RV = "REVOKED" + + # Job is being signaled. + SI = "SIGNALING" + + # The job was requeued in a special state. This state can be set by users, + # typically in EpilogSlurmctld, if the job has terminated with a particular + # exit value. + SE = "SPECIAL_EXIT" + + # Job is staging out files. + SO = "STAGE_OUT" + + # Job has an allocation, but execution has been stopped with SIGSTOP + # signal. CPUS have been retained by this job. + ST = "STOPPED" + + # Job has an allocation, but execution has been suspended and CPUs have + # been released for other jobs. + S = "SUSPENDED" + + # Job terminated upon reaching its time limit. + TO = "TIMEOUT" + + FINISHED_CODES = [BF, CA, CD, DL, F, NF, OOM, PR, RS, RV, SE, ST, S, TO] + + +class SubmitSlurmJob: + """Main class to handle submitting and monitoring a slurm job""" + + def __init__( + self, + slurm: SlurmApi, + job_properties: V0036JobProperties, + script: str, + poll_job_interval: int = 120, + ): + """ + Class constructor + Parameters + ---------- + slurm : SlurmApi + job_properties : V0036JobProperties + script : str + poll_job_interval : int + Number of seconds to wait before checking slurm job status. + Default is 120. + """ + self.slurm = slurm + self.job_properties = job_properties + self._set_default_job_props(self.job_properties) + self.script = script + self.polling_request_sleep = poll_job_interval + + @staticmethod + def _set_default_job_props(job_properties: V0036JobProperties) -> None: + """ + Set default values for the slurm job if they are not explicitly set + in the job_properties. + Parameters + ---------- + job_properties : V0036JobProperties + The job_properties used to initially construct the class. + """ + # Check if any default values need to be set + basic_attributes_to_check = [ + "name", + "memory_per_node", + "tasks", + "minimum_cpus_per_node", + "nodes", + "time_limit", + "qos", + "partition", + ] + for attribute in basic_attributes_to_check: + if getattr(job_properties, attribute) is None: + setattr( + job_properties, + attribute, + getattr(DefaultSlurmSettings(), attribute), + ) + if ( + job_properties.environment is None + or job_properties.environment == {} + ): + job_properties.environment = DefaultSlurmSettings().environment + if job_properties.standard_out is None: + job_properties.standard_out = str( + Path(DefaultSlurmSettings().log_path) + / f"{job_properties.name}.out" + ) + if job_properties.standard_error is None: + job_properties.standard_error = str( + Path(DefaultSlurmSettings().log_path) + / f"{job_properties.name}_error.out" + ) + + def _submit_job(self) -> V0036JobSubmissionResponse: + """Submit the job to the slurm cluster.""" + job_submission = V0036JobSubmission( + script=self.script, job=self.job_properties + ) + submit_response = self.slurm.slurmctld_submit_job_0( + v0036_job_submission=job_submission + ) + if submit_response.errors: + raise Exception( + f"There were errors submitting the job to slurm: " + f"{submit_response.errors}" + ) + return submit_response + + def _monitor_job( + self, submit_response: V0036JobSubmissionResponse + ) -> None: + """ + Monitor a job submitted to the slurm cluster. + Parameters + ---------- + submit_response : V0036JobSubmissionResponse + The initial job submission response. Used to extract the job_id. + + """ + + job_id = submit_response.job_id + job_name = self.job_properties.name + job_response = self.slurm.slurmctld_get_job_0(job_id=job_id) + errors = job_response.errors + start_time = ( + None if not job_response.jobs else job_response.jobs[0].start_time + ) + job_state = ( + None if not job_response.jobs else job_response.jobs[0].job_state + ) + message = json.dumps( + { + "job_id": job_id, + "job_name": job_name, + "job_state": job_state, + "start_time": start_time, + } + ) + logging.info(message) + while ( + job_state + and job_state not in JobState.FINISHED_CODES + and not errors + ): + sleep(self.polling_request_sleep) + job_response = self.slurm.slurmctld_get_job_0(job_id=job_id) + errors = job_response.errors + start_time = ( + None + if not job_response.jobs + else job_response.jobs[0].start_time + ) + job_state = ( + None + if not job_response.jobs + else job_response.jobs[0].job_state + ) + message = json.dumps( + { + "job_id": job_id, + "job_name": job_name, + "job_state": job_state, + "start_time": start_time, + } + ) + logging.info(message) + + if job_state != JobState.CD or errors: + message = json.dumps( + { + "job_id": job_id, + "job_name": job_name, + "job_state": job_state, + } + ) + raise Exception( + f"There were errors with the slurm job. " + f"Job: {message}. " + f"Errors: {errors}" + ) + else: + logging.info("Job is Finished!") + return None + + def run_job(self): + """Submit and monitor a job.""" + submit_response = self._submit_job() + self._monitor_job(submit_response=submit_response) + + @classmethod + def from_args(cls, system_args: List[str], slurm: SlurmApi): + """ + Create job from command line arguments + Parameters + ---------- + system_args : List[str] + slurm : SlurmApi + """ + parser = ArgumentParser() + parser.add_argument( + "--script-path", + type=str, + required=False, + help="Path to bash script for slurm to run", + ) + parser.add_argument( + "--script-encoded", + type=str, + required=False, + help="Bash script encoded as a hex string for slurm to run", + ) + parser.add_argument( + "--job-properties", + type=str, + required=True, + ) + job_args = parser.parse_args(system_args) + if job_args.script_path: + script_path = Path(job_args.script_path) + with open(script_path, "r") as f: + script = f.read() + elif job_args.script_encoded: + script = binascii.unhexlify(job_args.script_encoded).decode() + else: + raise AssertionError( + "Either script-path or script-encoded is needed" + ) + + job_properties_json = job_args.job_properties + job_properties = V0036JobProperties.model_validate_json( + job_properties_json + ) + return cls(script=script, job_properties=job_properties, slurm=slurm) + + +if __name__ == "__main__": + + sys_args = sys.argv[1:] + slurm_client_settings = SlurmClientSettings() + main_slurm = slurm_client_settings.create_api_client() + main_slurm_job = SubmitSlurmJob.from_args( + system_args=sys_args, slurm=main_slurm + ) + main_slurm_job.run_job() diff --git a/tests/resources/test_slurm_script.txt b/tests/resources/test_slurm_script.txt new file mode 100644 index 0000000..575e03c --- /dev/null +++ b/tests/resources/test_slurm_script.txt @@ -0,0 +1,2 @@ +#!/bin/bash +echo 'Hello World?' && sleep 120 && echo 'Goodbye!' \ No newline at end of file diff --git a/tests/test_example.py b/tests/test_example.py deleted file mode 100644 index 06e9e0d..0000000 --- a/tests/test_example.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Example test template.""" - -import unittest - - -class ExampleTest(unittest.TestCase): - """Example Test Class""" - - def test_assert_example(self): - """Example of how to test the truth of a statement.""" - - self.assertTrue(1 == 1) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_submit_slurm_job.py b/tests/test_submit_slurm_job.py new file mode 100644 index 0000000..fde08f8 --- /dev/null +++ b/tests/test_submit_slurm_job.py @@ -0,0 +1,475 @@ +"""Tests methods in the submit_slurm_jobs module.""" + +import binascii +import os +import unittest +from pathlib import Path +from unittest.mock import MagicMock, call, patch + +from aind_slurm_rest import ( + V0036Error, + V0036JobResponseProperties, + V0036JobsResponse, + V0036JobSubmissionResponse, +) +from aind_slurm_rest.models.v0036_job_properties import V0036JobProperties + +from aind_airflow_jobs.submit_slurm_job import ( + JobState, + SlurmClientSettings, + SubmitSlurmJob, +) + +TEST_DIR = Path(os.path.dirname(os.path.realpath(__file__))) / "resources" +EXAMPLE_SCRIPT = TEST_DIR / "test_slurm_script.txt" + + +class TestSubmitSlurmJob(unittest.TestCase): + """Test methods in the SubmitSlurmJob class""" + + EXAMPLE_ENV_VAR = { + "SLURM_CLIENT_HOST": "slurm", + "SLURM_CLIENT_USERNAME": "username", + "SLURM_CLIENT_PASSWORD": "password", + "SLURM_CLIENT_ACCESS_TOKEN": "abc-123", + "SLURM_LOG_PATH": "/a/dir/to/write/logs/to", + "SLURM_PARTITION": "some_part", + } + + @patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True) + def test_default_job_properties(self): + """Tests that default job properties are set correctly.""" + slurm_client_settings = SlurmClientSettings() + job_properties = V0036JobProperties(environment={}) + script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) + slurm = slurm_client_settings.create_api_client() + slurm_job = SubmitSlurmJob( + slurm=slurm, + script=script, + job_properties=job_properties, + ) + self.assertEqual("some_part", slurm_job.job_properties.partition) + self.assertEqual("dev", slurm_job.job_properties.qos) + self.assertTrue(slurm_job.job_properties.name.startswith("job_")) + self.assertTrue( + slurm_job.job_properties.standard_out.startswith( + "/a/dir/to/write/logs/to/job_" + ) + ) + self.assertTrue(slurm_job.job_properties.standard_out.endswith(".out")) + self.assertTrue( + slurm_job.job_properties.standard_error.startswith( + "/a/dir/to/write/logs/to/job_" + ) + ) + self.assertTrue( + slurm_job.job_properties.standard_error.endswith("_error.out") + ) + self.assertEqual( + { + "PATH": "/bin:/usr/bin/:/usr/local/bin/", + "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib", + }, + slurm_job.job_properties.environment, + ) + self.assertEqual(360, slurm_job.job_properties.time_limit) + self.assertEqual(50000, slurm_job.job_properties.memory_per_node) + self.assertEqual([1, 1], slurm_job.job_properties.nodes) + self.assertEqual(1, slurm_job.job_properties.tasks) + + @patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True) + def test_mixed_job_properties(self): + """Tests that job properties are not overwritten.""" + slurm_client_settings = SlurmClientSettings() + job_properties = V0036JobProperties( + environment={}, + name="my_job", + partition="part2", + qos="prod", + time_limit=5, + memory_per_node=500, + ) + script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) + slurm = slurm_client_settings.create_api_client() + slurm_job = SubmitSlurmJob( + slurm=slurm, + script=script, + job_properties=job_properties, + ) + self.assertEqual("part2", slurm_job.job_properties.partition) + self.assertEqual("prod", slurm_job.job_properties.qos) + self.assertTrue("my_job", slurm_job.job_properties.name) + self.assertEqual( + "/a/dir/to/write/logs/to/my_job.out", + slurm_job.job_properties.standard_out, + ) + self.assertEqual( + "/a/dir/to/write/logs/to/my_job_error.out", + slurm_job.job_properties.standard_error, + ) + self.assertEqual( + { + "PATH": "/bin:/usr/bin/:/usr/local/bin/", + "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib", + }, + slurm_job.job_properties.environment, + ) + self.assertEqual(5, slurm_job.job_properties.time_limit) + self.assertEqual(500, slurm_job.job_properties.memory_per_node) + self.assertEqual([1, 1], slurm_job.job_properties.nodes) + self.assertEqual(1, slurm_job.job_properties.tasks) + + @patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True) + @patch("aind_slurm_rest.api.slurm_api.SlurmApi.slurmctld_submit_job_0") + def test_submit_job_with_errors(self, mock_submit_job: MagicMock): + """Tests that an exception is raised if there are errors in the + SubmitJobResponse""" + + mock_submit_job.return_value = V0036JobSubmissionResponse( + errors=[V0036Error(error="An error occurred.")] + ) + slurm_client_settings = SlurmClientSettings() + job_properties = V0036JobProperties(environment={}) + script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) + slurm = slurm_client_settings.create_api_client() + slurm_job = SubmitSlurmJob( + slurm=slurm, + script=script, + job_properties=job_properties, + ) + with self.assertRaises(Exception) as e: + slurm_job._submit_job() + expected_errors = ( + "There were errors submitting the job to slurm: " + "[V0036Error(error='An error occurred.', errno=None)]" + ) + self.assertEqual(expected_errors, e.exception.args[0]) + + @patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True) + @patch("aind_slurm_rest.api.slurm_api.SlurmApi.slurmctld_submit_job_0") + def test_submit_job(self, mock_submit_job: MagicMock): + """Tests that job is submitted successfully""" + + mock_submit_job.return_value = V0036JobSubmissionResponse( + errors=[], job_id=12345 + ) + slurm_client_settings = SlurmClientSettings() + job_properties = V0036JobProperties(environment={}) + script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) + slurm = slurm_client_settings.create_api_client() + slurm_job = SubmitSlurmJob( + slurm=slurm, + script=script, + job_properties=job_properties, + ) + response = slurm_job._submit_job() + expected_response = V0036JobSubmissionResponse(errors=[], job_id=12345) + self.assertEqual(expected_response, response) + + @patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True) + @patch("aind_slurm_rest.api.slurm_api.SlurmApi.slurmctld_get_job_0") + @patch("aind_airflow_jobs.submit_slurm_job.sleep", return_value=None) + @patch("logging.info") + def test_monitor_job( + self, + mock_log_info: MagicMock, + mock_sleep: MagicMock, + mock_get_job: MagicMock, + ): + """Tests that job is monitored successfully""" + + submit_job_response = V0036JobSubmissionResponse( + errors=[], job_id=12345 + ) + + submit_time = 1693788246 + start_time = 1693788400 + + mock_get_job.side_effect = [ + V0036JobsResponse( + errors=[], + jobs=[ + V0036JobResponseProperties( + job_state=JobState.PD.value, submit_time=submit_time + ) + ], + ), + V0036JobsResponse( + errors=[], + jobs=[ + V0036JobResponseProperties( + job_state=JobState.R.value, + submit_time=submit_time, + start_time=start_time, + ) + ], + ), + V0036JobsResponse( + errors=[], + jobs=[ + V0036JobResponseProperties( + job_state=JobState.CD.value, + submit_time=submit_time, + start_time=start_time, + ) + ], + ), + ] + slurm_client_settings = SlurmClientSettings() + job_properties = V0036JobProperties(environment={}, name="mock_job") + script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) + slurm = slurm_client_settings.create_api_client() + slurm_job = SubmitSlurmJob( + slurm=slurm, + script=script, + job_properties=job_properties, + ) + slurm_job._monitor_job(submit_response=submit_job_response) + + mock_sleep.assert_has_calls([call(120), call(120)]) + + mock_log_info.assert_has_calls( + [ + call( + '{"job_id": 12345, "job_name": "mock_job", ' + '"job_state": "PENDING", "start_time": null}' + ), + call( + '{"job_id": 12345, "job_name": "mock_job", ' + '"job_state": "RUNNING", "start_time": 1693788400}' + ), + call( + '{"job_id": 12345, "job_name": "mock_job", ' + '"job_state": "COMPLETED", "start_time": 1693788400}' + ), + call("Job is Finished!"), + ] + ) + + @patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True) + @patch("aind_slurm_rest.api.slurm_api.SlurmApi.slurmctld_get_job_0") + @patch("aind_airflow_jobs.submit_slurm_job.sleep", return_value=None) + @patch("logging.info") + def test_monitor_job_with_errors( + self, + mock_log_info: MagicMock, + mock_sleep: MagicMock, + mock_get_job: MagicMock, + ): + """Tests that errors are raised if response has errors.""" + + submit_job_response = V0036JobSubmissionResponse( + errors=[], job_id=12345 + ) + + submit_time = 1693788246 + + mock_get_job.side_effect = [ + V0036JobsResponse( + errors=[V0036Error(error="An error occurred.")], + jobs=[ + V0036JobResponseProperties( + job_state=JobState.F.value, submit_time=submit_time + ) + ], + ) + ] + slurm_client_settings = SlurmClientSettings() + job_properties = V0036JobProperties(environment={}, name="mock_job") + script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) + slurm = slurm_client_settings.create_api_client() + slurm_job = SubmitSlurmJob( + slurm=slurm, + script=script, + job_properties=job_properties, + ) + with self.assertRaises(Exception) as e: + slurm_job._monitor_job(submit_response=submit_job_response) + + expected_error_message = ( + "There were errors with the slurm job. Job: " + '{"job_id": 12345, "job_name": "mock_job", "job_state": "FAILED"}.' + " Errors: [V0036Error(error='An error occurred.', errno=None)]" + ) + + self.assertEqual(expected_error_message, e.exception.args[0]) + mock_sleep.assert_not_called() + mock_log_info.assert_has_calls( + [ + call( + '{"job_id": 12345, "job_name": "mock_job", ' + '"job_state": "FAILED", "start_time": null}' + ) + ] + ) + + @patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True) + @patch("aind_slurm_rest.api.slurm_api.SlurmApi.slurmctld_get_job_0") + @patch("aind_airflow_jobs.submit_slurm_job.sleep", return_value=None) + @patch("logging.info") + def test_monitor_job_with_fail_code( + self, + mock_log_info: MagicMock, + mock_sleep: MagicMock, + mock_get_job: MagicMock, + ): + """Tests that errors are raised if response has an error code.""" + + submit_job_response = V0036JobSubmissionResponse( + errors=[], job_id=12345 + ) + + submit_time = 1693788246 + start_time = 1693788400 + + mock_get_job.side_effect = [ + V0036JobsResponse( + errors=[], + jobs=[ + V0036JobResponseProperties( + job_state=JobState.PD.value, submit_time=submit_time + ) + ], + ), + V0036JobsResponse( + errors=[], + jobs=[ + V0036JobResponseProperties( + job_state=JobState.R.value, + submit_time=submit_time, + start_time=start_time, + ) + ], + ), + V0036JobsResponse( + errors=[], + jobs=[ + V0036JobResponseProperties( + job_state=JobState.F.value, + submit_time=submit_time, + start_time=start_time, + ) + ], + ), + ] + slurm_client_settings = SlurmClientSettings() + job_properties = V0036JobProperties(environment={}, name="mock_job") + script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) + slurm = slurm_client_settings.create_api_client() + slurm_job = SubmitSlurmJob( + slurm=slurm, + script=script, + job_properties=job_properties, + ) + with self.assertRaises(Exception) as e: + slurm_job._monitor_job(submit_response=submit_job_response) + + expected_error_message = ( + "There were errors with the slurm job. Job: " + '{"job_id": 12345, "job_name": "mock_job", "job_state": "FAILED"}.' + " Errors: []" + ) + + self.assertEqual(expected_error_message, e.exception.args[0]) + mock_sleep.assert_has_calls([call(120), call(120)]) + mock_log_info.assert_has_calls( + [ + call( + '{"job_id": 12345, "job_name": "mock_job", ' + '"job_state": "PENDING", "start_time": null}' + ), + call( + '{"job_id": 12345, "job_name": "mock_job", ' + '"job_state": "RUNNING", "start_time": 1693788400}' + ), + call( + '{"job_id": 12345, "job_name": "mock_job", ' + '"job_state": "FAILED", "start_time": 1693788400}' + ), + ] + ) + + @patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True) + @patch("aind_airflow_jobs.submit_slurm_job.SubmitSlurmJob._submit_job") + @patch("aind_airflow_jobs.submit_slurm_job.SubmitSlurmJob._monitor_job") + def test_run_job(self, mock_monitor: MagicMock, mock_submit: MagicMock): + """Tests that run_job calls right methods.""" + slurm_client_settings = SlurmClientSettings() + job_properties = V0036JobProperties(environment={}) + script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"]) + slurm = slurm_client_settings.create_api_client() + slurm_job = SubmitSlurmJob( + slurm=slurm, + script=script, + job_properties=job_properties, + ) + + slurm_job.run_job() + mock_submit.assert_called() + mock_monitor.assert_called() + + @patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True) + def test_from_args_script_path(self): + """Tests that a job args can be input via the command line.""" + slurm_client_settings = SlurmClientSettings() + slurm = slurm_client_settings.create_api_client() + job_properties_json = V0036JobProperties( + environment={} + ).model_dump_json() + sys_args = [ + "--script-path", + str(EXAMPLE_SCRIPT), + "--job-properties", + job_properties_json, + ] + job = SubmitSlurmJob.from_args(slurm=slurm, system_args=sys_args) + expected_script = ( + "#!/bin/bash\n" + "echo 'Hello World?' && sleep 120 && echo 'Goodbye!'" + ) + self.assertEqual("some_part", job.job_properties.partition) + self.assertEqual(expected_script, job.script) + + @patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True) + def test_from_args_script_encoded(self): + """Tests that a job args can be input via the command line using an + encoded script.""" + expected_script = ( + "#!/bin/bash\n" + "echo 'Hello World?' && sleep 120 && echo 'Goodbye!'" + ) + script_encoded = binascii.hexlify(expected_script.encode()).decode() + slurm_client_settings = SlurmClientSettings() + slurm = slurm_client_settings.create_api_client() + job_properties_json = V0036JobProperties( + environment={} + ).model_dump_json() + sys_args = [ + "--script-encoded", + script_encoded, + "--job-properties", + job_properties_json, + ] + job = SubmitSlurmJob.from_args(slurm=slurm, system_args=sys_args) + self.assertEqual("some_part", job.job_properties.partition) + self.assertEqual(expected_script, job.script) + + @patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True) + def test_from_args_error(self): + """Tests that an error is raised if no script is set""" + slurm_client_settings = SlurmClientSettings() + slurm = slurm_client_settings.create_api_client() + job_properties_json = V0036JobProperties( + environment={} + ).model_dump_json() + sys_args = ["--job-properties", job_properties_json] + with self.assertRaises(Exception) as e: + SubmitSlurmJob.from_args(slurm=slurm, system_args=sys_args) + self.assertEqual( + "Either script-path or script-encoded is needed", + e.exception.args[0], + ) + + +if __name__ == "__main__": + unittest.main()