diff --git a/src/nifreeze/model/gpr.py b/src/nifreeze/model/gpr.py index 19f88bf..905f976 100644 --- a/src/nifreeze/model/gpr.py +++ b/src/nifreeze/model/gpr.py @@ -30,6 +30,7 @@ import numpy as np from scipy import optimize from scipy.optimize._minimize import Bounds +from scipy.spatial.distance import cdist, pdist, squareform from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.gaussian_process.kernels import ( Hyperparameter, @@ -40,6 +41,8 @@ BOUNDS_A: tuple[float, float] = (0.1, 2.35) """The limits for the parameter *a* (angular distance in rad).""" +BOUNDS_ELL: tuple[float, float] = (0.1, 2.35) +"""The limits for the parameter *$\ell$* (shell distance in $s/mm^2$).""" BOUNDS_LAMBDA: tuple[float, float] = (1e-3, 1000) """The limits for the parameter λ (signal scaling factor).""" THETA_EPSILON: float = 1e-5 @@ -469,6 +472,121 @@ def __repr__(self) -> str: return f"SphericalKriging (a={self.beta_a}, λ={self.beta_l})" +class SquaredExponentialKriging(Kernel): + """A scikit-learn's kernel for DWI signals.""" + + def __init__( + self, + beta_ell: float = 1.38, + beta_l: float = 0.5, + ell_bounds: tuple[float, float] = BOUNDS_ELL, + l_bounds: tuple[float, float] = BOUNDS_LAMBDA, + ): + r""" + Initialize a spherical Kriging kernel. + + Parameters + ---------- + beta_ell : :obj:`float`, optional + Minimum angle in rads. + beta_l : :obj:`float`, optional + The :math:`\lambda` hyperparameter. + ell_bounds : :obj:`tuple`, optional + Bounds for the :math:`\ell` parameter. + l_bounds : :obj:`tuple`, optional + Bounds for the :math:`\lambda` hyperparameter. + + """ + self.beta_ell = beta_ell + self.beta_l = beta_l + self.a_bounds = ell_bounds + self.l_bounds = l_bounds + + @property + def hyperparameter_ell(self) -> Hyperparameter: + return Hyperparameter("beta_ell", "numeric", self.a_bounds) + + @property + def hyperparameter_l(self) -> Hyperparameter: + return Hyperparameter("beta_l", "numeric", self.l_bounds) + + def __call__( + self, X: np.ndarray, Y: np.ndarray | None = None, eval_gradient: bool = False + ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: + """ + Return the kernel K(X, Y) and optionally its gradient. + + Parameters + ---------- + X : :obj:`~numpy.ndarray` + Gradient wighting values (X) + Y : :obj:`~numpy.ndarray`, optional + Gradient wighting values (Y, optional) + eval_gradient : :obj:`bool`, optional + Determines whether the gradient with respect to the log of + the kernel hyperparameter is computed. + Only supported when Y is ``None``. + + Returns + ------- + K : :obj:`~numpy.ndarray` of shape (n_samples_X, n_samples_Y) + Kernel k(X, Y) + + K_gradient : :obj:`~numpy.ndarray` of shape (n_samples_X, n_samples_X, n_dims),\ + optional + The gradient of the kernel k(X, X) with respect to the log of the + hyperparameter of the kernel. Only returned when ``eval_gradient`` + is True. + + """ + + dists = compute_shell_distance(X, Y=Y) + C_b = squared_exponential_covariance(dists, self.beta_ell) + + if Y is None: + C_b = squareform(C_b) + np.fill_diagonal(C_b, 1) + + if not eval_gradient: + return self.beta_l * C_b + + # Looking at this + # https://github.com/scikit-learn/scikit-learn/blob/1e6a81f322f1821cc605a18b08fcc198c7d63c97/sklearn/gaussian_process/kernels.py#L1574 + # Not sure the derivative is clear to me. IMO it should be + # \frac{d}{dx} \left( e^{-\frac{cte}{x^2}} \right) = \frac{2 cte}{x^3} \cdot e^{-\frac{cte}{x^2}} + # where x is ell, and cte is 0.5 * (\log b - \log b')^2 + + K_gradient = 1 # ToDo + + return self.beta_l * C_b, K_gradient + + def diag(self, X: np.ndarray) -> np.ndarray: + """Returns the diagonal of the kernel k(X, X). + + The result of this method is identical to np.diag(self(X)); however, + it can be evaluated more efficiently since only the diagonal is + evaluated. + + Parameters + ---------- + X : :obj:`~numpy.ndarray` of shape (n_samples_X, n_features) + Left argument of the returned kernel k(X, Y) + + Returns + ------- + K_diag : :obj:`~numpy.ndarray` of shape (n_samples_X,) + Diagonal of kernel k(X, X) + """ + return self.beta_l * np.ones(X.shape[0]) + + def is_stationary(self) -> bool: + """Returns whether the kernel is stationary.""" + return True + + def __repr__(self) -> str: + return f"SquaredExponentialKriging (wll={self.beta_ell}, λ={self.beta_l})" + + def exponential_covariance(theta: np.ndarray, a: float) -> np.ndarray: r""" Compute the exponential covariance for given distances and scale parameter. @@ -590,3 +708,71 @@ def compute_pairwise_angles( thetas = np.arccos(np.abs(cosines)) if closest_polarity else np.arccos(cosines) thetas[np.abs(thetas) < THETA_EPSILON] = 0.0 return thetas + + +def squared_exponential_covariance( + shell_distance: np.ndarray, + ell: float, +) -> np.ndarray: + r"""Compute the squared exponential covariance for given diffusion gradient + encoding weighting distances and scale parameter. + + Implements :math:`C_{b}`, following Eq. (15) in [Andersson15]_: + + .. math:: + + C_{b}(b, b'; \ell) = \exp\left( - \frac{(\log b - \log b')^2}{2 \ell^2} \right) + + The squared exponential covariance function is sometimes called radial basis + function (RBF) or Gaussian kernel. + + Parameters + ---------- + shell_distance : :obj:`~numpy.ndarray` of shape (n_samples_X, n_features) + Input data. + ell : float + Distance parameter where the covariance function goes to zero. + + Returns + ------- + :obj:`~numpy.ndarray` + Squared exponential covariance values for the input distances. + """ + + return np.exp(-0.5 * (shell_distance / (ell**2))) + + +def compute_shell_distance(X, Y=None): + r"""Compute pairwise angles across diffusion gradient encoding weighting + values. + + Following Eq. (15) in [Andersson15]_, computes the distance between the log + values of the diffusion gradient encoding weighting values: + + .. math:: + + \log b - \log b' + + Parameters + ---------- + X : :obj:`~numpy.ndarray` of shape (n_samples_X, n_features) + Input data. + Y : :obj:`~numpy.ndarray` of shape (n_samples_Y, n_features), optional + Input data. If ``None``, the output will be the pairwise + similarities between all samples in ``X``. + + Returns + ------- + :obj:`~numpy.ndarray` + Pairwise distances of diffusion gradient encoding weighting values. + """ + + # ToDo + # scikit-learn RBF call includes $\ell$ here; fine, but then I do not get + # the derivative computation the way they compute it + if Y is None: + dists = pdist(np.log(X), metric="sqeuclidean") + else: + dists = cdist(np.log(X), np.log(Y), metric="sqeuclidean") + + return dists diff --git a/test/test_gpr.py b/test/test_gpr.py index 8d19974..0ba0c35 100644 --- a/test/test_gpr.py +++ b/test/test_gpr.py @@ -271,8 +271,8 @@ def test_compute_pairwise_angles(bvecs1, bvecs2, closest_polarity, expected): np.testing.assert_array_almost_equal(obtained, expected, decimal=2) -@pytest.mark.parametrize("covariance", ["Spherical", "Exponential"]) -def test_kernel(repodata, covariance): +@pytest.mark.parametrize("covariance", ["Spherical", "Exponential", "SquaredExponential"]) +def test_kernel_single_shell(repodata, covariance): """Check kernel construction.""" bvals, bvecs = read_bvals_bvecs( @@ -296,3 +296,65 @@ def test_kernel(repodata, covariance): K_predict = kernel(bvecs, bvecs[10:14, ...]) assert K_predict.shape == (K.shape[0], 4) + + +# ToDo +@pytest.mark.parametrize( + ("bvals1", "bvals2", "expected"), + [ + ( + np.array( + [ + [1000, 1000, 1000, 1000], + ] + ), + None, + np.array( + [ + [0, 0, 0, 0], + ] + ), + ), + ( + np.array( + [ + [1000, 1000, 1000, 1000], + [2000, 2000, 2000], + ] + ), + None, + np.array( + [ + [1000, 1000, 1000, 1000], + ] + ), + ), + ], +) +def test_compute_shell_distance(bvals1, bvals2, expected): + + obtained = gpr.compute_shell_distance(bvals1, bvals2) + + if bvals2 is not None: + assert (bvals1.shape[0], bvals2.shape[0]) == obtained.shape + assert obtained.shape == expected.shape + np.testing.assert_array_almost_equal(obtained, expected, decimal=2) + + +@pytest.mark.parametrize("covariance", ["SquaredExponential"]) +def test_kernel_multi_shell(repodata, covariance): + """Check kernel construction.""" + + bvals, bvecs = read_bvals_bvecs( + str(repodata / "ds000114_multishell.bval"), + str(repodata / "ds000114_multishell.bvec"), + ) + + bvals = bvals[bvals > 10] + + KernelType = getattr(gpr, f"{covariance}Kriging") + kernel = KernelType() + + K = kernel(bvals) + + assert K.shape == (bvals.shape[0],) * 2