Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow extra_filter argument in the settings json by specifying a file_path::function_name string #88

Merged
merged 1 commit into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,4 @@ isort.required-imports = ["from __future__ import annotations"]
"**.ipynb" = ["B008", "T20", "I002", "E402", "E703", "B018"]
"src/egamma_tnp/nanoaod_efficiency.py" = ["PLW2901"]
"src/egamma_tnp/__init__.py" = ["E402"]
"tests/example_extra_filter.py" = ["T201"]
34 changes: 34 additions & 0 deletions src/egamma_tnp/utils/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,37 @@ def filter_class_args(class_, args):
return {k: v for k, v in args.items() if k in sig.parameters}


def load_function_from_file(function_path):
"""Load a function from a file and return it."""
# Split the file path and function name
if "::" in function_path:
file_path, function_name = function_path.split("::")
else:
logger.error(f"Function name not provided in the format 'path::function': {function_path}")
raise ValueError(f"Function name not provided in the format 'path::function': {function_path}")

# Check if the file exists
if not os.path.exists(file_path):
logger.error(f"Function file not found: {file_path}")
raise FileNotFoundError(f"Function file not found: {file_path}")

# Load and execute the file content
with open(file_path) as file:
code = compile(file.read(), file_path, "exec")
local_scope = {} # Use a restricted local scope
exec(code, {}, local_scope) # Execute code in isolated scope

# Return the function if a name is provided, otherwise return all loaded objects
if function_name:
if function_name in local_scope:
return local_scope[function_name]
else:
logger.error(f"Function '{function_name}' not found in {file_path}")
raise ValueError(f"Function '{function_name}' not found in {file_path}")
else:
return local_scope


def initialize_class(config, args, fileset):
"""Initialize the appropriate Tag and Probe class based on the workflow specified in the config."""
class_map = {
Expand All @@ -104,6 +135,9 @@ def initialize_class(config, args, fileset):
workflow = class_map[class_name]
class_args = config["workflow_args"] | filter_class_args(workflow, vars(args))
class_args.pop("fileset")
if args.extra_filter:
extra_filter = load_function_from_file(args.extra_filter)
class_args["extra_filter"] = extra_filter
logger.info(f"Initializing workflow {workflow} with args: {class_args}")
return workflow(fileset=fileset, **class_args)

Expand Down
8 changes: 8 additions & 0 deletions tests/example_extra_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import annotations


def extra_filter(events, extra_filter_arg1, extra_filter_arg2):
print("I'm an extra filter")
print(f"extra_filter_arg1: {extra_filter_arg1}")
print(f"extra_filter_arg2: {extra_filter_arg2}")
return events
7 changes: 5 additions & 2 deletions tests/example_settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
"cutbased_id": null,
"extra_tags_mask": null,
"extra_probes_mask": null,
"extra_filter": null,
"extra_filter_args": {},
"extra_filter": "tests/example_extra_filter.py::extra_filter",
"extra_filter_args": {
"extra_filter_arg1": 1,
"extra_filter_arg2": 2
},
"use_sc_eta": true,
"use_sc_phi": false,
"avoid_ecal_transition_tags": true,
Expand Down
Loading