diff --git a/src/evaluate.py b/src/evaluate.py index 5622aa2..dec9ed0 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -125,5 +125,6 @@ def main(cfg: DictConfig) -> None: if __name__ == "__main__": + utils.replace_sys_args_with_values_from_files() utils.prepare_omegaconf() main() diff --git a/src/evaluate_documents.py b/src/evaluate_documents.py index 5f12719..0df8586 100644 --- a/src/evaluate_documents.py +++ b/src/evaluate_documents.py @@ -106,5 +106,6 @@ def main(cfg: DictConfig) -> Any: if __name__ == "__main__": + utils.replace_sys_args_with_values_from_files() utils.prepare_omegaconf() main() diff --git a/src/predict.py b/src/predict.py index 522a3db..d05c8cd 100644 --- a/src/predict.py +++ b/src/predict.py @@ -132,5 +132,6 @@ def main(cfg: DictConfig) -> None: if __name__ == "__main__": + utils.replace_sys_args_with_values_from_files() utils.prepare_omegaconf() main() diff --git a/src/train.py b/src/train.py index 101841f..5536370 100644 --- a/src/train.py +++ b/src/train.py @@ -229,5 +229,6 @@ def main(cfg: DictConfig) -> Optional[float]: if __name__ == "__main__": + utils.replace_sys_args_with_values_from_files() utils.prepare_omegaconf() main() diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 6f9f927..a7b70da 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,4 +1,4 @@ from .config_utils import execute_pipeline, instantiate_dict_entries, prepare_omegaconf from .logging_utils import close_loggers, get_pylogger, log_hyperparameters from .rich_utils import enforce_tags, print_config_tree -from .task_utils import extras, save_file, task_wrapper +from .task_utils import extras, replace_sys_args_with_values_from_files, save_file, task_wrapper diff --git a/src/utils/task_utils.py b/src/utils/task_utils.py index 37a6aa7..3b8fd3d 100644 --- a/src/utils/task_utils.py +++ b/src/utils/task_utils.py @@ -1,7 +1,10 @@ +import json +import os +import sys import time import warnings from pathlib import Path -from typing import 
import json
import os
import sys
from typing import Any


def load_value_from_file(path: str, split_path_key: str = ":", split_key_parts: str = "/") -> Any:
    """Load a value from a JSON file, optionally descending to a nested entry.

    The ``path`` argument has the form ``<file>[<split_path_key><key>]``. The part
    before the first ``split_path_key`` is the file to read; the optional rest is a
    key path whose segments are separated by ``split_key_parts``. Each segment
    selects a member of a JSON object by name, or an element of a JSON array by
    its integer position. For now, only ``.json`` files are supported.

    Example:
        ``results.json:model/output_dir/0`` reads ``results.json`` and returns
        ``data["model"]["output_dir"][0]``.

    Note:
        Only the first ``split_path_key`` is treated as the separator, so with the
        default ``":"`` a file path that itself contains a colon (e.g. a Windows
        drive letter) cannot be addressed; pass a different ``split_path_key``.

    Args:
        path: path to the file (and, optionally, to the data within the file)
        split_path_key: separator between the file path and the key path
        split_key_parts: separator between the nested keys of the key path

    Returns:
        The complete parsed file content if no key path is given, otherwise the
        value found at the key path. This can be any JSON value (dict, list,
        string, number, boolean or None).

    Raises:
        ValueError: if the file extension is not ``.json``, or if a key segment
            that addresses a JSON array is not an integer.
        KeyError / IndexError: if a key segment does not exist in the data.
    """
    file_path, separator, key_path = path.partition(split_path_key)

    file_extension = os.path.splitext(file_path)[1]
    if file_extension != ".json":
        raise ValueError(f"Expected .json file, got {file_extension}")

    # JSON text is UTF-8 by specification; do not depend on the platform locale.
    with open(file_path, encoding="utf-8") as f:
        data = json.load(f)

    # No separator at all means "return the whole document".
    if not separator:
        return data

    for key in key_path.split(split_key_parts):
        # JSON arrays are addressed by position, JSON objects by member name.
        data = data[int(key)] if isinstance(data, list) else data[key]
    return data


def replace_sys_args_with_values_from_files(
    load_prefix: str = "LOAD_ARG:",
    load_multi_prefix: str = "LOAD_MULTI_ARG:",
    **load_value_from_file_kwargs: Any,
) -> None:
    """Replace arguments in sys.argv with values loaded from files.

    Every command line argument (``sys.argv[1:]``) that contains ``load_prefix`` or
    ``load_multi_prefix`` is rewritten: the text before the prefix is kept, and the
    text after the prefix is passed to ``load_value_from_file``. The loaded value is
    serialized as compact JSON with all double quotes removed. For
    ``load_multi_prefix`` the value must be a list and the enclosing brackets are
    removed as well, which yields a plain comma separated sequence. Arguments that
    contain neither prefix are passed through unchanged. If an argument contains
    both prefixes, ``load_prefix`` takes precedence.

    Examples:
        # job_return_value.json contains {"a": 1, "b": 2}
        python train.py LOAD_ARG:job_return_value.json
        # this will pass {a:1,b:2} as the first argument to train.py

        # job_return_value.json contains [1, 2, 3]
        python train.py LOAD_MULTI_ARG:job_return_value.json
        # this will pass 1,2,3 as the first argument to train.py

        # job_return_value.json contains
        # {"model": {"output_dir": ["path1", "path2"], "f1": [0.7, 0.6]}}
        python train.py load_model=LOAD_MULTI_ARG:job_return_value.json:model/output_dir
        # this will pass load_model=path1,path2 to train.py
        python train.py load_model=LOAD_ARG:job_return_value.json:model/output_dir
        # this will pass load_model=[path1,path2] to train.py

    Args:
        load_prefix: the prefix to use for loading a single value from a file
        load_multi_prefix: the prefix to use for loading a list of values from a file
        **load_value_from_file_kwargs: additional kwargs to pass to load_value_from_file

    Raises:
        ValueError: if ``load_multi_prefix`` is used but the loaded value is not a list.
    """
    updated_args = []
    for arg in sys.argv[1:]:
        if load_prefix in arg:
            prefix, is_multirun_arg = load_prefix, False
        elif load_multi_prefix in arg:
            prefix, is_multirun_arg = load_multi_prefix, True
        else:
            updated_args.append(arg)
            continue

        # The prefix is known to be present, so this always yields the text
        # before the prefix and the file reference after it.
        arg_head, _, file_reference = arg.partition(prefix)
        # NOTE: `log` is expected to be the module-level logger of this file
        # (it is defined outside of this block).
        log.warning(f'Replacing argument value for "{arg_head}" with content from {file_reference}')
        value = load_value_from_file(file_reference, **load_value_from_file_kwargs)

        # Validate before formatting so that the error names the real problem.
        if is_multirun_arg and not isinstance(value, list):
            raise ValueError(
                f"Expected list for multirun argument, got {type(value)}. If you just want "
                f"to set a single value, use {load_prefix} instead of {load_multi_prefix}."
            )

        # Compact separators drop only the *structural* whitespace, so spaces inside
        # string values (e.g. paths) survive; ensure_ascii=False keeps non-ASCII
        # characters instead of turning them into \uXXXX escapes.
        value_str = json.dumps(value, separators=(",", ":"), ensure_ascii=False)
        # Remove the JSON string quotes.
        value_str = value_str.replace('"', "")
        if is_multirun_arg:
            # Remove the outer brackets of the list: [1,2,3] -> 1,2,3
            value_str = value_str[1:-1]
        updated_args.append(f"{arg_head}{value_str}")

    # Keep the program name (if any) and install the rewritten arguments.
    sys.argv = sys.argv[:1] + updated_args