Skip to content

Commit

Permalink
Replace sys args with values from files (#122)
Browse files Browse the repository at this point in the history
* implement replace_sys_args_with_values_from_files() and load_value_from_file()

* call replace_sys_args_with_values_from_files() in main entry scripts
  • Loading branch information
ArneBinder authored Aug 30, 2023
1 parent fae1a01 commit 8e26c0d
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,5 +125,6 @@ def main(cfg: DictConfig) -> None:


if __name__ == "__main__":
    # Rewrite sys.argv first: arguments that carry a load prefix ("LOAD_ARG:" or
    # "LOAD_MULTI_ARG:" by default) are replaced with values read from JSON files.
    utils.replace_sys_args_with_values_from_files()
    # NOTE(review): presumably prepares OmegaConf (e.g. resolvers) before the config is
    # composed - confirm in utils.config_utils.
    utils.prepare_omegaconf()
    # main is declared with a `cfg` parameter but is called without arguments here;
    # presumably a decorator builds cfg from the (already rewritten) sys.argv - TODO confirm.
    main()
1 change: 1 addition & 0 deletions src/evaluate_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,6 @@ def main(cfg: DictConfig) -> Any:


if __name__ == "__main__":
    # Rewrite sys.argv first: arguments that carry a load prefix ("LOAD_ARG:" or
    # "LOAD_MULTI_ARG:" by default) are replaced with values read from JSON files.
    utils.replace_sys_args_with_values_from_files()
    # NOTE(review): presumably prepares OmegaConf (e.g. resolvers) before the config is
    # composed - confirm in utils.config_utils.
    utils.prepare_omegaconf()
    # main is declared with a `cfg` parameter but is called without arguments here;
    # presumably a decorator builds cfg from the (already rewritten) sys.argv - TODO confirm.
    main()
1 change: 1 addition & 0 deletions src/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,5 +132,6 @@ def main(cfg: DictConfig) -> None:


if __name__ == "__main__":
    # Rewrite sys.argv first: arguments that carry a load prefix ("LOAD_ARG:" or
    # "LOAD_MULTI_ARG:" by default) are replaced with values read from JSON files.
    utils.replace_sys_args_with_values_from_files()
    # NOTE(review): presumably prepares OmegaConf (e.g. resolvers) before the config is
    # composed - confirm in utils.config_utils.
    utils.prepare_omegaconf()
    # main is declared with a `cfg` parameter but is called without arguments here;
    # presumably a decorator builds cfg from the (already rewritten) sys.argv - TODO confirm.
    main()
1 change: 1 addition & 0 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,5 +229,6 @@ def main(cfg: DictConfig) -> Optional[float]:


if __name__ == "__main__":
    # Rewrite sys.argv first: arguments that carry a load prefix ("LOAD_ARG:" or
    # "LOAD_MULTI_ARG:" by default) are replaced with values read from JSON files.
    utils.replace_sys_args_with_values_from_files()
    # NOTE(review): presumably prepares OmegaConf (e.g. resolvers) before the config is
    # composed - confirm in utils.config_utils.
    utils.prepare_omegaconf()
    # main is declared with a `cfg` parameter but is called without arguments here;
    # presumably a decorator builds cfg from the (already rewritten) sys.argv - TODO confirm.
    main()
2 changes: 1 addition & 1 deletion src/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .config_utils import execute_pipeline, instantiate_dict_entries, prepare_omegaconf
from .logging_utils import close_loggers, get_pylogger, log_hyperparameters
from .rich_utils import enforce_tags, print_config_tree
from .task_utils import extras, save_file, task_wrapper
from .task_utils import extras, replace_sys_args_with_values_from_files, save_file, task_wrapper
93 changes: 92 additions & 1 deletion src/utils/task_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import json
import os
import sys
import time
import warnings
from pathlib import Path
from typing import Callable
from typing import Callable, Dict

from omegaconf import DictConfig
from pytorch_lightning.utilities import rank_zero_only
Expand Down Expand Up @@ -85,3 +88,91 @@ def save_file(path: str, content: str) -> None:
"""Save file in rank zero mode (only on one process in multi-GPU setup)."""
with open(path, "w+") as file:
file.write(content)


def load_value_from_file(path: str, split_path_key: str = ":", split_key_parts: str = "/") -> object:
    """Load a value from a JSON file, optionally selecting a nested element.

    The path may carry a key after the file name (separated by ``split_path_key``), and
    that key may address nested elements (separated by ``split_key_parts``), e.g.
    ``results.json:model/output_dir/0``. For now, only .json files are supported.

    Args:
        path: path to the file, optionally followed by ``split_path_key`` and a key path.
            Only the first occurrence of ``split_path_key`` separates file and key, so the
            file path itself must not contain it.
        split_path_key: separator between the file path and the key within the file
        split_key_parts: separator between the parts of a nested key

    Returns:
        The whole parsed JSON document if no key is given, otherwise the addressed element
        (any JSON value: dict, list, str, number, bool or None).

    Raises:
        ValueError: if the file does not have a .json extension.
        KeyError: if a key part is missing from a JSON object.
        IndexError: if a key part is out of range for a JSON array.
        TypeError: if a JSON array is addressed with a non-integer key part, or a key part
            is applied to a scalar value.
    """
    file_path, separator, key_path = path.partition(split_path_key)

    file_extension = os.path.splitext(file_path)[1]
    if file_extension.lower() != ".json":
        raise ValueError(f"Expected .json file, got {file_extension}")

    # JSON text is UTF-8 by specification; do not depend on the locale's default encoding.
    with open(file_path, encoding="utf-8") as f:
        data = json.load(f)

    if not separator:
        return data

    for key in key_path.split(split_key_parts):
        if isinstance(data, list):
            # Key parts are always strings; JSON arrays need an integer position.
            try:
                index = int(key)
            except ValueError:
                raise TypeError(
                    f"Cannot index a list with non-integer key {key!r} (from {path!r})"
                ) from None
            data = data[index]
        else:
            data = data[key]
    return data


def replace_sys_args_with_values_from_files(
    load_prefix: str = "LOAD_ARG:",
    load_multi_prefix: str = "LOAD_MULTI_ARG:",
    **load_value_from_file_kwargs,
) -> None:
    """Replace arguments in sys.argv with values loaded from files.

    Every command line argument that contains ``load_prefix`` or ``load_multi_prefix`` is
    rewritten in place: the text before the prefix is kept, and the prefix together with the
    file reference after it is replaced by the compacted JSON value loaded from that file.
    The value is serialized with ``json.dumps`` and then ALL double quotes and spaces are
    removed (including those inside string values). For the multi prefix the value must be
    a list and its outer brackets are stripped, leaving a comma separated sequence.

    Examples:
        # job_return_value.json contains {"a": 1, "b": 2}
        python train.py LOAD_ARG:job_return_value.json
        # this will pass "{a:1,b:2}" as the first argument to train.py

        # job_return_value.json contains [1, 2, 3]
        python train.py LOAD_MULTI_ARG:job_return_value.json
        # this will pass "1,2,3" as the first argument to train.py

        # job_return_value.json contains
        # {"model": {"output_dir": ["path1", "path2"], "f1": [0.7, 0.6]}}
        python train.py load_model=LOAD_ARG:job_return_value.json:model/output_dir
        # this will pass "load_model=[path1,path2]" to train.py
        python train.py load_model=LOAD_MULTI_ARG:job_return_value.json:model/output_dir
        # this will pass "load_model=path1,path2" to train.py

    Args:
        load_prefix: the prefix to use for loading a single value from a file
        load_multi_prefix: the prefix to use for loading a list of values from a file
        **load_value_from_file_kwargs: additional kwargs to pass to load_value_from_file

    Raises:
        ValueError: if ``load_multi_prefix`` is used but the loaded value is not a list.
    """
    updated_args = []
    for arg in sys.argv[1:]:
        # load_prefix takes precedence if an argument happens to contain both prefixes.
        if load_prefix in arg:
            prefix, is_multirun_arg = load_prefix, False
        elif load_multi_prefix in arg:
            prefix, is_multirun_arg = load_multi_prefix, True
        else:
            updated_args.append(arg)
            continue

        # The prefix is known to be present, so this always yields the text before the
        # prefix and the file reference after its first occurrence.
        arg_head, _, file_reference = arg.partition(prefix)
        log.warning(
            f'Replacing argument value for "{arg_head}" with content from {file_reference}'
        )
        json_value = load_value_from_file(file_reference, **load_value_from_file_kwargs)
        if is_multirun_arg and not isinstance(json_value, list):
            raise ValueError(
                f"Expected list for multirun argument, got {type(json_value)}. If you just want "
                f"to set a single value, use {load_prefix} instead of {load_multi_prefix}."
            )

        # Compact the JSON text: drop all double quotes and spaces.
        json_value_str = json.dumps(json_value).replace('"', "").replace(" ", "")
        if is_multirun_arg:
            # Strip the outer brackets so the list becomes a comma separated sequence.
            json_value_str = json_value_str[1:-1]
        updated_args.append(f"{arg_head}{json_value_str}")

    # Keep the program name and swap in the rewritten arguments.
    sys.argv = [sys.argv[0]] + updated_args

0 comments on commit 8e26c0d

Please sign in to comment.