Skip to content

Commit

Permalink
Merge same env_vars when creating array wrapper. (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanluoyc authored Feb 5, 2024
1 parent c8097dc commit c5c4c27
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
21 changes: 21 additions & 0 deletions lxm3/xm_cluster/execution/job_script.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import collections
import os
import re
from typing import (
Expand Down Expand Up @@ -214,10 +215,30 @@ def _create_env_vars(env_vars_list: List[Dict[str, str]]) -> str:
first_keys = set(env_vars_list[0].keys())
if not first_keys:
return ""

# Find out keys that are common to all environment variables
var_to_values = collections.defaultdict(list)
for env in env_vars_list:
for k, v in env.items():
var_to_values[k].append(v)

common_keys = []
for k, v in var_to_values.items():
if len(set(v)) == 1:
common_keys.append(k)
common_keys = sorted(common_keys)

for env_vars in env_vars_list:
if first_keys != set(env_vars.keys()):
raise ValueError("Expect all environment variables to have the same keys")

# Generate shared environment variables
for k in sorted(common_keys):
lines.append('export {key}="{value}"'.format(key=k, value=env_vars_list[0][k]))

for key in first_keys:
if key in common_keys:
continue
for task_id, env_vars in enumerate(env_vars_list):
lines.append(
'{key}_{task_id}="{value}"'.format(
Expand Down
12 changes: 12 additions & 0 deletions tests/execution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ def test_env_vars(self):
def test_empty_env_vars(self):
self.assertEqual(job_script._create_env_vars([{}]), "")

def test_common_values(self):
env_var_str = job_script._create_env_vars(
[{"FOO": "BAR", "BAR": "1"}, {"FOO": "BAR", "BAR": "2"}]
)
expected = """\
export FOO="BAR"
BAR_0="1"
BAR_1="2"
BAR=$(eval echo \\$"BAR_$LXM_TASK_ID")
export BAR"""
self.assertEqual(env_var_str, expected)

def test_different_keys(self):
with self.assertRaises(ValueError):
job_script._create_env_vars([{"FOO": "BAR1"}, {"BAR": "BAR2"}])
Expand Down

0 comments on commit c5c4c27

Please sign in to comment.