Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hot fix remove defaults #4

Merged
merged 3 commits into from
Mar 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 2 additions & 77 deletions src/aind_airflow_jobs/submit_slurm_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,18 @@
import logging
import sys
from argparse import ArgumentParser
from datetime import datetime
from enum import Enum
from pathlib import Path
from time import sleep
from typing import Dict, List
from uuid import uuid4
from typing import List

from aind_slurm_rest import ApiClient as Client
from aind_slurm_rest import Configuration as Config
from aind_slurm_rest import V0036JobSubmissionResponse
from aind_slurm_rest.api.slurm_api import SlurmApi
from aind_slurm_rest.models.v0036_job_properties import V0036JobProperties
from aind_slurm_rest.models.v0036_job_submission import V0036JobSubmission
from pydantic import Field, SecretStr
from pydantic import SecretStr
from pydantic_settings import BaseSettings, SettingsConfigDict

logging.basicConfig(level="INFO")
Expand Down Expand Up @@ -57,34 +55,6 @@ def create_api_client(self) -> SlurmApi:
return slurm


def _default_job_name() -> str:
    """Return a default job name: ``job_<epoch seconds>_<5 uuid chars>``."""
    # Local import: the module-level import only brings in ``datetime``.
    from datetime import timezone

    # Use an aware datetime. ``datetime.utcnow()`` returns a naive value and
    # ``timestamp()`` treats naive values as local time, which shifts the
    # resulting epoch by the host's UTC offset.
    epoch_seconds = int(datetime.now(timezone.utc).timestamp())
    return f"job_{epoch_seconds}_{str(uuid4())[0:5]}"


class DefaultSlurmSettings(BaseSettings):
    """Configurations with default values or expected to be pulled from env
    vars.

    Fields are resolved by pydantic-settings using the ``SLURM_`` prefix
    (for example ``SLURM_LOG_PATH`` and ``SLURM_PARTITION``). ``log_path``
    and ``partition`` have no default, so constructing this class fails
    validation unless they are supplied.
    """

    model_config = SettingsConfigDict(env_prefix="SLURM_")

    # Required settings (no default value).
    log_path: str
    partition: str

    # Generated per instance so that separately created settings get
    # different names.
    name: str = Field(default_factory=_default_job_name)
    qos: str = Field(default="dev")
    environment: Dict[str, str] = Field(
        default={
            "PATH": "/bin:/usr/bin/:/usr/local/bin/",
            "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib",
        }
    )
    # NOTE(review): units for memory_per_node and time_limit are defined by
    # the slurm REST API, not by this file -- confirm before changing.
    memory_per_node: int = Field(default=50000)
    tasks: int = Field(default=1)
    minimum_cpus_per_node: int = Field(default=1)
    nodes: List[int] = Field(default=[1, 1])
    time_limit: int = Field(default=360)


class JobState(str, Enum):
"""The possible job_state values in the V0036JobsResponse class. The enums
don't appear to be importable from the aind-slurm-rest api."""
Expand Down Expand Up @@ -196,54 +166,9 @@ def __init__(
"""
self.slurm = slurm
self.job_properties = job_properties
self._set_default_job_props(self.job_properties)
self.script = script
self.polling_request_sleep = poll_job_interval

@staticmethod
def _set_default_job_props(job_properties: V0036JobProperties) -> None:
    """
    Set default values for the slurm job if they are not explicitly set
    in the job_properties. The object is modified in place.

    The defaults come from DefaultSlurmSettings, which is built at most
    once and only when at least one value is actually missing. Fully
    specified job_properties therefore never trigger settings validation.

    Parameters
    ----------
    job_properties : V0036JobProperties
      The job_properties used to initially construct the class.
    """
    cached_defaults = None

    def defaults():
        """Return the lazily created, shared DefaultSlurmSettings."""
        nonlocal cached_defaults
        if cached_defaults is None:
            cached_defaults = DefaultSlurmSettings()
        return cached_defaults

    # Attributes that are copied straight from the defaults when unset.
    basic_attributes_to_check = (
        "name",
        "memory_per_node",
        "tasks",
        "minimum_cpus_per_node",
        "nodes",
        "time_limit",
        "qos",
        "partition",
    )
    for attribute in basic_attributes_to_check:
        if getattr(job_properties, attribute) is None:
            setattr(job_properties, attribute, getattr(defaults(), attribute))

    # Both None and an empty mapping count as "no environment supplied".
    if not job_properties.environment:
        job_properties.environment = defaults().environment

    # Log file names are derived from the job name, which is guaranteed to
    # be set by the loop above.
    if job_properties.standard_out is None:
        job_properties.standard_out = str(
            Path(defaults().log_path) / f"{job_properties.name}.out"
        )
    if job_properties.standard_error is None:
        job_properties.standard_error = str(
            Path(defaults().log_path) / f"{job_properties.name}_error.out"
        )

def _submit_job(self) -> V0036JobSubmissionResponse:
"""Submit the job to the slurm cluster."""
job_submission = V0036JobSubmission(
Expand Down
171 changes: 114 additions & 57 deletions tests/test_submit_slurm_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,24 @@ class TestSubmitSlurmJob(unittest.TestCase):
"SLURM_CLIENT_USERNAME": "username",
"SLURM_CLIENT_PASSWORD": "password",
"SLURM_CLIENT_ACCESS_TOKEN": "abc-123",
"SLURM_LOG_PATH": "/a/dir/to/write/logs/to",
"SLURM_PARTITION": "some_part",
}

@patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True)
def test_default_job_properties(self):
"""Tests that default job properties are set correctly."""
slurm_client_settings = SlurmClientSettings()
job_properties = V0036JobProperties(environment={})
job_properties = V0036JobProperties(
environment={
"LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib",
"PATH": "/bin:/usr/bin/:/usr/local/bin/",
},
partition="some_part",
standard_error="/a/dir/to/write/logs/to/job_123_error.out",
standard_out="/a/dir/to/write/logs/to/job_123.out",
qos="dev",
name="job_123",
time_limit=360,
)
script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"])
slurm = slurm_client_settings.create_api_client()
slurm_job = SubmitSlurmJob(
Expand Down Expand Up @@ -73,51 +82,6 @@ def test_default_job_properties(self):
slurm_job.job_properties.environment,
)
self.assertEqual(360, slurm_job.job_properties.time_limit)
self.assertEqual(50000, slurm_job.job_properties.memory_per_node)
self.assertEqual([1, 1], slurm_job.job_properties.nodes)
self.assertEqual(1, slurm_job.job_properties.tasks)

@patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True)
def test_mixed_job_properties(self):
    """Tests that job properties are not overwritten.

    Explicitly supplied values (name, partition, qos, time_limit,
    memory_per_node) must be kept, while the values left unset
    (environment, log files, nodes, tasks) are filled in with defaults.
    """
    slurm_client_settings = SlurmClientSettings()
    job_properties = V0036JobProperties(
        environment={},
        name="my_job",
        partition="part2",
        qos="prod",
        time_limit=5,
        memory_per_node=500,
    )
    script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"])
    slurm = slurm_client_settings.create_api_client()
    slurm_job = SubmitSlurmJob(
        slurm=slurm,
        script=script,
        job_properties=job_properties,
    )
    self.assertEqual("part2", slurm_job.job_properties.partition)
    self.assertEqual("prod", slurm_job.job_properties.qos)
    # assertEqual, not assertTrue: assertTrue("my_job", ...) always passes
    # because its second argument is only the failure message.
    self.assertEqual("my_job", slurm_job.job_properties.name)
    self.assertEqual(
        "/a/dir/to/write/logs/to/my_job.out",
        slurm_job.job_properties.standard_out,
    )
    self.assertEqual(
        "/a/dir/to/write/logs/to/my_job_error.out",
        slurm_job.job_properties.standard_error,
    )
    self.assertEqual(
        {
            "PATH": "/bin:/usr/bin/:/usr/local/bin/",
            "LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib",
        },
        slurm_job.job_properties.environment,
    )
    self.assertEqual(5, slurm_job.job_properties.time_limit)
    self.assertEqual(500, slurm_job.job_properties.memory_per_node)
    self.assertEqual([1, 1], slurm_job.job_properties.nodes)
    self.assertEqual(1, slurm_job.job_properties.tasks)

@patch.dict(os.environ, EXAMPLE_ENV_VAR, clear=True)
@patch("aind_slurm_rest.api.slurm_api.SlurmApi.slurmctld_submit_job_0")
Expand All @@ -129,7 +93,18 @@ def test_submit_job_with_errors(self, mock_submit_job: MagicMock):
errors=[V0036Error(error="An error occurred.")]
)
slurm_client_settings = SlurmClientSettings()
job_properties = V0036JobProperties(environment={})
job_properties = V0036JobProperties(
environment={
"LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib",
"PATH": "/bin:/usr/bin/:/usr/local/bin/",
},
partition="some_part",
standard_error="/a/dir/to/write/logs/to/job_123_error.out",
standard_out="/a/dir/to/write/logs/to/job_123.out",
qos="dev",
name="job_123",
time_limit=360,
)
script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"])
slurm = slurm_client_settings.create_api_client()
slurm_job = SubmitSlurmJob(
Expand All @@ -154,7 +129,18 @@ def test_submit_job(self, mock_submit_job: MagicMock):
errors=[], job_id=12345
)
slurm_client_settings = SlurmClientSettings()
job_properties = V0036JobProperties(environment={})
job_properties = V0036JobProperties(
environment={
"LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib",
"PATH": "/bin:/usr/bin/:/usr/local/bin/",
},
partition="some_part",
standard_error="/a/dir/to/write/logs/to/job_123_error.out",
standard_out="/a/dir/to/write/logs/to/job_123.out",
qos="dev",
name="job_123",
time_limit=360,
)
script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"])
slurm = slurm_client_settings.create_api_client()
slurm_job = SubmitSlurmJob(
Expand Down Expand Up @@ -216,7 +202,18 @@ def test_monitor_job(
),
]
slurm_client_settings = SlurmClientSettings()
job_properties = V0036JobProperties(environment={}, name="mock_job")
job_properties = V0036JobProperties(
environment={
"LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib",
"PATH": "/bin:/usr/bin/:/usr/local/bin/",
},
partition="some_part",
standard_error="/a/dir/to/write/logs/to/job_123_error.out",
standard_out="/a/dir/to/write/logs/to/job_123.out",
qos="dev",
name="mock_job",
time_limit=360,
)
script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"])
slurm = slurm_client_settings.create_api_client()
slurm_job = SubmitSlurmJob(
Expand Down Expand Up @@ -275,7 +272,18 @@ def test_monitor_job_with_errors(
)
]
slurm_client_settings = SlurmClientSettings()
job_properties = V0036JobProperties(environment={}, name="mock_job")
job_properties = V0036JobProperties(
environment={
"LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib",
"PATH": "/bin:/usr/bin/:/usr/local/bin/",
},
partition="some_part",
standard_error="/a/dir/to/write/logs/to/job_123_error.out",
standard_out="/a/dir/to/write/logs/to/job_123.out",
qos="dev",
name="mock_job",
time_limit=360,
)
script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"])
slurm = slurm_client_settings.create_api_client()
slurm_job = SubmitSlurmJob(
Expand Down Expand Up @@ -353,7 +361,18 @@ def test_monitor_job_with_fail_code(
),
]
slurm_client_settings = SlurmClientSettings()
job_properties = V0036JobProperties(environment={}, name="mock_job")
job_properties = V0036JobProperties(
environment={
"LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib",
"PATH": "/bin:/usr/bin/:/usr/local/bin/",
},
partition="some_part",
standard_error="/a/dir/to/write/logs/to/job_123_error.out",
standard_out="/a/dir/to/write/logs/to/job_123.out",
qos="dev",
name="mock_job",
time_limit=360,
)
script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"])
slurm = slurm_client_settings.create_api_client()
slurm_job = SubmitSlurmJob(
Expand Down Expand Up @@ -395,7 +414,18 @@ def test_monitor_job_with_fail_code(
def test_run_job(self, mock_monitor: MagicMock, mock_submit: MagicMock):
"""Tests that run_job calls right methods."""
slurm_client_settings = SlurmClientSettings()
job_properties = V0036JobProperties(environment={})
job_properties = V0036JobProperties(
environment={
"LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib",
"PATH": "/bin:/usr/bin/:/usr/local/bin/",
},
partition="some_part",
standard_error="/a/dir/to/write/logs/to/job_123_error.out",
standard_out="/a/dir/to/write/logs/to/job_123.out",
qos="dev",
name="job_123",
time_limit=360,
)
script = " ".join(["#!/bin/bash", "\necho", "'Hello World?'"])
slurm = slurm_client_settings.create_api_client()
slurm_job = SubmitSlurmJob(
Expand All @@ -414,7 +444,16 @@ def test_from_args_script_path(self):
slurm_client_settings = SlurmClientSettings()
slurm = slurm_client_settings.create_api_client()
job_properties_json = V0036JobProperties(
environment={}
environment={
"LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib",
"PATH": "/bin:/usr/bin/:/usr/local/bin/",
},
partition="some_part",
standard_error="/a/dir/to/write/logs/to/job_123_error.out",
standard_out="/a/dir/to/write/logs/to/job_123.out",
qos="dev",
name="job_123",
time_limit=360,
).model_dump_json()
sys_args = [
"--script-path",
Expand Down Expand Up @@ -442,7 +481,16 @@ def test_from_args_script_encoded(self):
slurm_client_settings = SlurmClientSettings()
slurm = slurm_client_settings.create_api_client()
job_properties_json = V0036JobProperties(
environment={}
environment={
"LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib",
"PATH": "/bin:/usr/bin/:/usr/local/bin/",
},
partition="some_part",
standard_error="/a/dir/to/write/logs/to/job_123_error.out",
standard_out="/a/dir/to/write/logs/to/job_123.out",
qos="dev",
name="job_123",
time_limit=360,
).model_dump_json()
sys_args = [
"--script-encoded",
Expand All @@ -460,7 +508,16 @@ def test_from_args_error(self):
slurm_client_settings = SlurmClientSettings()
slurm = slurm_client_settings.create_api_client()
job_properties_json = V0036JobProperties(
environment={}
environment={
"LD_LIBRARY_PATH": "/lib/:/lib64/:/usr/local/lib",
"PATH": "/bin:/usr/bin/:/usr/local/bin/",
},
partition="some_part",
standard_error="/a/dir/to/write/logs/to/job_123_error.out",
standard_out="/a/dir/to/write/logs/to/job_123.out",
qos="dev",
name="job_123",
time_limit=360,
).model_dump_json()
sys_args = ["--job-properties", job_properties_json]
with self.assertRaises(Exception) as e:
Expand Down
Loading