Skip to content

Commit

Permalink
allow extra_filter argument in the settings json by specifying a file…
Browse files Browse the repository at this point in the history
…_path::function_name string
  • Loading branch information
ikrommyd committed Sep 14, 2024
1 parent 3f73b6d commit d0e15b0
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 2 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,4 @@ isort.required-imports = ["from __future__ import annotations"]
"**.ipynb" = ["B008", "T20", "I002", "E402", "E703", "B018"]
"src/egamma_tnp/nanoaod_efficiency.py" = ["PLW2901"]
"src/egamma_tnp/__init__.py" = ["E402"]
"tests/example_extra_filter.py" = ["T201"]
34 changes: 34 additions & 0 deletions src/egamma_tnp/utils/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,37 @@ def filter_class_args(class_, args):
return {k: v for k, v in args.items() if k in sig.parameters}


def load_function_from_file(function_path):
"""Load a function from a file and return it."""
# Split the file path and function name
if "::" in function_path:
file_path, function_name = function_path.split("::")
else:
logger.error(f"Function name not provided in the format 'path::function': {function_path}")
raise ValueError(f"Function name not provided in the format 'path::function': {function_path}")

# Check if the file exists
if not os.path.exists(file_path):
logger.error(f"Function file not found: {file_path}")
raise FileNotFoundError(f"Function file not found: {file_path}")

# Load and execute the file content
with open(file_path) as file:
code = compile(file.read(), file_path, "exec")
local_scope = {} # Use a restricted local scope
exec(code, {}, local_scope) # Execute code in isolated scope

# Return the function if a name is provided, otherwise return all loaded objects
if function_name:
if function_name in local_scope:
return local_scope[function_name]
else:
logger.error(f"Function '{function_name}' not found in {file_path}")
raise ValueError(f"Function '{function_name}' not found in {file_path}")
else:
return local_scope


def initialize_class(config, args, fileset):
"""Initialize the appropriate Tag and Probe class based on the workflow specified in the config."""
class_map = {
Expand All @@ -104,6 +135,9 @@ def initialize_class(config, args, fileset):
workflow = class_map[class_name]
class_args = config["workflow_args"] | filter_class_args(workflow, vars(args))
class_args.pop("fileset")
if args.extra_filter:
extra_filter = load_function_from_file(args.extra_filter)
class_args["extra_filter"] = extra_filter
logger.info(f"Initializing workflow {workflow} with args: {class_args}")
return workflow(fileset=fileset, **class_args)

Expand Down
8 changes: 8 additions & 0 deletions tests/example_extra_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import annotations


def extra_filter(events, extra_filter_arg1, extra_filter_arg2):
print("I'm an extra filter")
print(f"extra_filter_arg1: {extra_filter_arg1}")
print(f"extra_filter_arg2: {extra_filter_arg2}")
return events
7 changes: 5 additions & 2 deletions tests/example_settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
"cutbased_id": null,
"extra_tags_mask": null,
"extra_probes_mask": null,
"extra_filter": null,
"extra_filter_args": {},
"extra_filter": "tests/example_extra_filter.py::extra_filter",
"extra_filter_args": {
"extra_filter_arg1": 1,
"extra_filter_arg2": 2
},
"use_sc_eta": true,
"use_sc_phi": false,
"avoid_ecal_transition_tags": true,
Expand Down

0 comments on commit d0e15b0

Please sign in to comment.