diff --git a/gpax/utils/fn.py b/gpax/utils/fn.py index b4f1ca5..3119e09 100644 --- a/gpax/utils/fn.py +++ b/gpax/utils/fn.py @@ -10,9 +10,10 @@ import inspect import re -from typing import List, Callable, Optional +from typing import List, Callable, Optional, Dict import jax +import jax.numpy as jnp from ..kernels.kernels import square_scaled_distance, add_jitter, _sqrt