diff --git a/pyproject.toml b/pyproject.toml index 8f21c1c6..2d084e0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,3 +126,4 @@ isort.required-imports = ["from __future__ import annotations"] "**.ipynb" = ["B008", "T20", "I002", "E402", "E703", "B018"] "src/egamma_tnp/nanoaod_efficiency.py" = ["PLW2901"] "src/egamma_tnp/__init__.py" = ["E402"] +"tests/example_extra_filter.py" = ["T201"] diff --git a/src/egamma_tnp/utils/runner_utils.py b/src/egamma_tnp/utils/runner_utils.py index 9a032e5c..a8c0d9a2 100644 --- a/src/egamma_tnp/utils/runner_utils.py +++ b/src/egamma_tnp/utils/runner_utils.py @@ -92,6 +92,37 @@ def filter_class_args(class_, args): return {k: v for k, v in args.items() if k in sig.parameters} +def load_function_from_file(function_path): + """Load a function from a file and return it.""" + # Split the file path and function name + if "::" in function_path: + file_path, function_name = function_path.split("::") + else: + logger.error(f"Function name not provided in the format 'path::function': {function_path}") + raise ValueError(f"Function name not provided in the format 'path::function': {function_path}") + + # Check if the file exists + if not os.path.exists(file_path): + logger.error(f"Function file not found: {file_path}") + raise FileNotFoundError(f"Function file not found: {file_path}") + + # Load and execute the file content + with open(file_path) as file: + code = compile(file.read(), file_path, "exec") + local_scope = {} # Use a restricted local scope + exec(code, {}, local_scope) # Execute code in isolated scope + + # Return the function if a name is provided, otherwise return all loaded objects + if function_name: + if function_name in local_scope: + return local_scope[function_name] + else: + logger.error(f"Function '{function_name}' not found in {file_path}") + raise ValueError(f"Function '{function_name}' not found in {file_path}") + else: + return local_scope + + def initialize_class(config, args, fileset): """Initialize the appropriate Tag and Probe class based on the workflow specified in the config.""" class_map = { @@ -104,6 +135,9 @@ def initialize_class(config, args, fileset): workflow = class_map[class_name] class_args = config["workflow_args"] | filter_class_args(workflow, vars(args)) class_args.pop("fileset") + if args.extra_filter: + extra_filter = load_function_from_file(args.extra_filter) + class_args["extra_filter"] = extra_filter logger.info(f"Initializing workflow {workflow} with args: {class_args}") return workflow(fileset=fileset, **class_args) diff --git a/tests/example_extra_filter.py b/tests/example_extra_filter.py new file mode 100644 index 00000000..27af99cd --- /dev/null +++ b/tests/example_extra_filter.py @@ -0,0 +1,8 @@ +from __future__ import annotations + + +def extra_filter(events, extra_filter_arg1, extra_filter_arg2): + print("I'm an extra filter") + print(f"extra_filter_arg1: {extra_filter_arg1}") + print(f"extra_filter_arg2: {extra_filter_arg2}") + return events diff --git a/tests/example_settings.json b/tests/example_settings.json index bae5b7c2..f4fb91e0 100644 --- a/tests/example_settings.json +++ b/tests/example_settings.json @@ -6,8 +6,11 @@ "cutbased_id": null, "extra_tags_mask": null, "extra_probes_mask": null, - "extra_filter": null, - "extra_filter_args": {}, + "extra_filter": "tests/example_extra_filter.py::extra_filter", + "extra_filter_args": { + "extra_filter_arg1": 1, + "extra_filter_arg2": 2 + }, "use_sc_eta": true, "use_sc_phi": false, "avoid_ecal_transition_tags": true,