Skip to content

Commit

Permalink
Improve parameter replacement in set_fn function
Browse files Browse the repository at this point in the history
- Use regular expressions with word boundaries to replace parameter names in the source function.
- Ensure robust replacement even in cases where parameter names are surrounded by characters other than spaces, like parentheses.
- This resolves issues with functions that have parameters within expressions or parentheses.
  • Loading branch information
ziatdinovmax authored Oct 9, 2023
1 parent c6dab96 commit 3526b82
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions gpax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def uniform_dist(low: float = None,
high = high if high is not None else input_vec.max()

return numpyro.distributions.Uniform(low, high)


def set_fn(func: Callable) -> Callable:
"""
Expand All @@ -289,9 +289,9 @@ def set_fn(func: Callable) -> Callable:
source = inspect.getsource(func).split("\n", 1)[1]
source = " " + source.replace("\n", "\n ")

# Replace each parameter name with its dictionary lookup
# Replace each parameter name with its dictionary lookup using regex
for name in params_names:
source = source.replace(f" {name}", f' params["{name}"]')
source = re.sub(rf'\b{name}\b', f'params["{name}"]', source)

# Combine to get the full source
transformed_code += source
Expand All @@ -302,7 +302,7 @@ def set_fn(func: Callable) -> Callable:

# Return the transformed function
return local_namespace[func.__name__]


def auto_normal_priors(func: Callable, loc: float = 0.0, scale: float = 1.0) -> Callable:
"""
Expand Down

0 comments on commit 3526b82

Please sign in to comment.