diff --git a/gpax/utils.py b/gpax/utils.py index f7b11c4..1f9bdef 100644 --- a/gpax/utils.py +++ b/gpax/utils.py @@ -265,7 +265,7 @@ def uniform_dist(low: float = None, high = high if high is not None else input_vec.max() return numpyro.distributions.Uniform(low, high) - + def set_fn(func: Callable) -> Callable: """ @@ -289,9 +289,9 @@ def set_fn(func: Callable) -> Callable: source = inspect.getsource(func).split("\n", 1)[1] source = " " + source.replace("\n", "\n ") - # Replace each parameter name with its dictionary lookup + # Replace each parameter name with its dictionary lookup using regex for name in params_names: - source = source.replace(f" {name}", f' params["{name}"]') + source = re.sub(rf'\b{name}\b', f'params["{name}"]', source) # Combine to get the full source transformed_code += source @@ -302,7 +302,7 @@ def set_fn(func: Callable) -> Callable: # Return the transformed function return local_namespace[func.__name__] - + def auto_normal_priors(func: Callable, loc: float = 0.0, scale: float = 1.0) -> Callable: """