Skip to content

Commit

Permalink
Merge pull request #4 from Ciela-Institute/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
AlexandreAdam authored Jul 11, 2024
2 parents 02af79f + 55ed6d6 commit 2cbeed5
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 22 deletions.
23 changes: 10 additions & 13 deletions src/milex_scheduler/job_to_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
def create_slurm_script(job: dict, date: datetime, machine_config: dict) -> str:
"""Creates a SLURM script and saves it locally"""
user_settings = load_config()
path = os.path.join(user_settings['local']['path'], "slurm")
path = os.path.join(user_settings["local"]["path"], "slurm")
slurm_name = name_slurm_script(job, date)
with open(os.path.join(path, slurm_name), 'w') as f:
with open(os.path.join(path, slurm_name), "w") as f:
write_slurm_content(f, job, machine_config)
print(f"Saved SLURM script for job {job['name']} saved to {path}")
return slurm_name
Expand All @@ -22,21 +22,21 @@ def write_slurm_content(file: TextIOWrapper, job: dict, machine_config: dict) ->
"""
Writes the content of the SLURM script with formatted arguments, handling list arguments differently based on their type.
"""
env_command = machine_config.get('env_command', '')
slurm_account = machine_config.get('slurm_account', '')
env_command = machine_config.get("env_command", "")
slurm_account = machine_config.get("slurm_account", "")

file.write("#!/bin/bash\n")
if slurm_account:
file.write(f"#SBATCH --account={slurm_account}\n")
output_dir = os.path.join(machine_config['path'], "slurm")
output_dir = os.path.join(machine_config["path"], "slurm")
file.write(f"#SBATCH --output={os.path.join(output_dir, '%x-%j.out')}\n")
file.write(f"#SBATCH --job-name={job['name']}\n")

# SLURM directives
for key, value in job['slurm'].items():
for key, value in job["slurm"].items():
if value is not None:
file.write(f"#SBATCH --{key.replace('_', '-')}={value}\n")

# Make sure path is exported to environment
file.write(f"export MILEX=\"{machine_config['path']}\"\n")

Expand All @@ -50,11 +50,9 @@ def write_slurm_content(file: TextIOWrapper, job: dict, machine_config: dict) ->

# Main command and arguments
file.write(f"{job['script']} \\\n")
args_keys = list(job['args'].keys())
if args_keys:
last_arg = args_keys[-1] # Get the last argument key in case there are arguments
job_args = job.get("script_args", {})

for k, v in job['args'].items():
for i, (k, v) in enumerate(job_args.items()):
if v is None:
continue
if isinstance(v, bool):
Expand All @@ -71,9 +69,8 @@ def write_slurm_content(file: TextIOWrapper, job: dict, machine_config: dict) ->
else:
arg_line = f" --{k}={v}"

if k != last_arg:
if i < len(job_args) - 1:
arg_line += " \\\n"
else:
arg_line += "\n"
file.write(arg_line)

8 changes: 4 additions & 4 deletions tests/integration/test_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@
"mem": "4G",
"time": "01:00:00",
},
"args": {"param1": "value1", "param2": "value2"},
"script_args": {"param1": "value1", "param2": "value2"},
},
"JobB": {
"name": "JobB",
"script": "run-job-b",
"dependencies": ["JobA"],
"slurm": {"tasks": 1, "cpus_per_task": 2, "mem": "8G", "time": "02:00:00"},
"args": {"param1": "value3", "param2": "value4"},
"script_args": {"param1": "value3", "param2": "value4"},
},
"JobC": {
"name": "JobC",
"script": "run-job-c",
"dependencies": ["JobA", "JobB"],
"slurm": {"tasks": 1, "cpus_per_task": 4, "mem": "16G", "time": "03:00:00"},
"args": {"param1": "value5", "param2": "value6"},
"script_args": {"param1": "value5", "param2": "value6"},
"pre_commands": ["echo 'Starting Job C'"],
},
}
Expand Down Expand Up @@ -220,7 +220,7 @@ def test_integration_schedule_jobs(
"mem": "4G",
"time": "01:00:00",
},
"args": {"param1": "value1", "param2": "value2"},
"script_args": {"param1": "value1", "param2": "value2"},
},
}

Expand Down
10 changes: 5 additions & 5 deletions tests/unit/test_job_to_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_write_slurm_content(mock_load_config):
job = {
"name": "test_job",
"slurm": {"time": "01:00:00", "partition": "test-partition", "array": "1-10%6"},
"args": {"arg1": "value1", "arg2": [1, 2, 3]},
"script_args": {"arg1": "value1", "arg2": [1, 2, 3]},
"script": "test-application",
}
file = StringIO()
Expand Down Expand Up @@ -48,7 +48,7 @@ def test_write_slurm_boolean_flag(conditional_flag, expected_line, mock_load_con
job = {
"name": "boolean_flag_test",
"slurm": {},
"args": {"conditional": conditional_flag},
"script_args": {"conditional": conditional_flag},
"script": "test-boolean-application",
}
file = StringIO()
Expand All @@ -68,7 +68,7 @@ def test_write_slurm_with_none_value(mock_load_config):
job = {
"name": "none_value_test",
"slurm": {},
"args": {"arg_with_none": None}, # Test handling None value
"script_args": {"arg_with_none": None}, # Test handling None value
"script": "test-none-application",
}
file = StringIO()
Expand All @@ -88,7 +88,7 @@ def test_write_slurm_with_pre_commands_and_env_command(mock_load_config):
job = {
"name": "pre_commands_test",
"slurm": {},
"args": {},
"script_args": {},
"script": "test-pre-commands-application",
"pre_commands": ["module load python", "module load cuda"],
}
Expand All @@ -110,7 +110,7 @@ def test_write_slurm_output_dir_customization(mock_load_config):
job = {
"name": "output_dir_test",
"slurm": {"output": "custom-output-%j.txt"},
"args": {},
"script_args": {},
"script": "test-output-dir-application",
}
file = StringIO()
Expand Down

0 comments on commit 2cbeed5

Please sign in to comment.