diff --git a/docs/notebooks/dmri_covariance.ipynb b/docs/notebooks/dmri_covariance.ipynb index aeecc707..e27ceff8 100644 --- a/docs/notebooks/dmri_covariance.ipynb +++ b/docs/notebooks/dmri_covariance.ipynb @@ -24,7 +24,7 @@ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", - "from eddymotion.model._sklearn import (\n", + "from eddymotion.model.gpr import (\n", " compute_pairwise_angles,\n", " exponential_covariance,\n", " spherical_covariance,\n", @@ -345,7 +345,7 @@ } ], "source": [ - "from eddymotion.model._sklearn import EddyMotionGPR, SphericalKriging\n", + "from eddymotion.model.gpr import EddyMotionGPR, SphericalKriging\n", "\n", "K = SphericalKriging(beta_a=PARAMETER_SPHERICAL_a, beta_l=PARAMETER_lambda)(X_real)\n", "K -= K.min()\n", diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py index 0d68dba7..2591f58d 100644 --- a/scripts/dwi_gp_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -36,7 +36,7 @@ import pandas as pd from sklearn.model_selection import KFold, RepeatedKFold, cross_val_predict, cross_val_score -from eddymotion.model._sklearn import ( +from eddymotion.model.gpr import ( EddyMotionGPR, SphericalKriging, ) @@ -63,7 +63,7 @@ def cross_validate( Number of folds. n_repeats : :obj:`int` Number of times the cross-validator needs to be repeated. - gpr : obj:`~eddymotion.model._sklearn.EddyMotionGPR` + gpr : obj:`~eddymotion.model.gpr.EddyMotionGPR` The eddymotion Gaussian process regressor object. Returns diff --git a/scripts/dwi_gp_estimation_simulated_signal.py b/scripts/dwi_gp_estimation_simulated_signal.py index 3b534c68..ade883f7 100644 --- a/scripts/dwi_gp_estimation_simulated_signal.py +++ b/scripts/dwi_gp_estimation_simulated_signal.py @@ -33,7 +33,7 @@ import numpy as np from dipy.core.sphere import Sphere -from eddymotion.model._sklearn import EddyMotionGPR, SphericalKriging +from eddymotion.model.gpr import EddyMotionGPR, SphericalKriging from eddymotion.testing import simulations as testsims SAMPLING_DIRECTIONS = 200 diff --git a/src/eddymotion/estimator.py b/src/eddymotion/estimator.py index 0322fc05..ada6bccb 100644 --- a/src/eddymotion/estimator.py +++ b/src/eddymotion/estimator.py @@ -30,7 +30,7 @@ from eddymotion import utils as eutils from eddymotion.data.splitting import lovo_split -from eddymotion.model import ModelFactory +from eddymotion.model.base import ModelFactory from eddymotion.registration.ants import _prepare_registration_data, _run_registration diff --git a/src/eddymotion/model/_dipy.py b/src/eddymotion/model/_dipy.py index f1a1f0d9..d7eb1773 100644 --- a/src/eddymotion/model/_dipy.py +++ b/src/eddymotion/model/_dipy.py @@ -31,7 +31,7 @@ from dipy.reconst.base import ReconstModel from sklearn.gaussian_process import GaussianProcessRegressor -from eddymotion.model._sklearn import ( +from eddymotion.model.gpr import ( EddyMotionGPR, ExponentialKriging, SphericalKriging, diff --git a/src/eddymotion/model/_sklearn.py b/src/eddymotion/model/gpr.py similarity index 98% rename from src/eddymotion/model/_sklearn.py rename to src/eddymotion/model/gpr.py index a3b333ad..ca1242f0 100644 --- a/src/eddymotion/model/_sklearn.py +++ b/src/eddymotion/model/gpr.py @@ -23,6 +23,7 @@ r""" Derivations from scikit-learn for Gaussian Processes. + Gaussian Process Model: Pairwise orientation angles --------------------------------------------------- Squared Exponential covariance kernel @@ -101,10 +102,10 @@ from sklearn.metrics.pairwise import cosine_similarity from sklearn.utils._param_validation import Interval, StrOptions -BOUNDS_A: tuple[float, float] = (0.1, np.pi) -"""The limits for the parameter *a*.""" +BOUNDS_A: tuple[float, float] = (0.1, 0.75 * np.pi) +"""The limits for the parameter *a* (angular distance).""" BOUNDS_LAMBDA: tuple[float, float] = (1e-3, 1000) -"""The limits for the parameter lambda.""" +"""The limits for the parameter λ (signal scaling factor).""" THETA_EPSILON: float = 1e-5 """Minimum nonzero angle.""" LBFGS_CONFIGURABLE_OPTIONS = {"disp", "maxiter", "ftol", "gtol"} @@ -143,8 +144,7 @@ class EddyMotionGPR(GaussianProcessRegressor): In principle, Scikit-Learn's implementation normalizes the training data as in [Andersson15]_ (see - `FSL's souce code `__). + `FSL's souce code `__). From their paper (p. 167, end of first column): Typically one just substracts the mean (:math:`\bar{\mathbf{f}}`) @@ -161,7 +161,7 @@ class EddyMotionGPR(GaussianProcessRegressor): I believe this is overlooked in [Andersson15]_, or they actually did not use analytical gradient-descent: - _A note on optimisation_ + *A note on optimisation* It is suggested, for example in Rasmussen and Williams (2006), that an optimisation method that uses derivative information should be @@ -184,7 +184,7 @@ class EddyMotionGPR(GaussianProcessRegressor): "optimizer": [StrOptions(SUPPORTED_OPTIMIZERS), callable, None], "n_restarts_optimizer": [Interval(Integral, 0, None, closed="left")], "copy_X_train": ["boolean"], - "zeromean_y": ["boolean"], + "normalize_y": ["boolean"], "n_targets": [Interval(Integral, 1, None, closed="left"), None], "random_state": ["random_state"], } diff --git a/test/test_sklearn.py b/test/test_gpr.py similarity index 98% rename from test/test_sklearn.py rename to test/test_gpr.py index 277ba0b9..efc42b26 100644 --- a/test/test_sklearn.py +++ b/test/test_gpr.py @@ -26,7 +26,7 @@ import pytest from dipy.io import read_bvals_bvecs -from eddymotion.model import _sklearn as ems +from eddymotion.model import gpr GradientTablePatch = namedtuple("gtab", ["bvals", "bvecs"]) @@ -263,7 +263,7 @@ def test_compute_pairwise_angles(bvecs1, bvecs2, closest_polarity, expected): if bvecs2 is not None: _bvecs2 = (bvecs2 / np.linalg.norm(bvecs2, axis=0)).T - obtained = ems.compute_pairwise_angles(_bvecs1, _bvecs2, closest_polarity) + obtained = gpr.compute_pairwise_angles(_bvecs1, _bvecs2, closest_polarity) if _bvecs2 is not None: assert (_bvecs1.shape[0], _bvecs2.shape[0]) == obtained.shape @@ -282,7 +282,7 @@ def test_kernel(repodata, covariance): bvecs = bvecs[bvals > 10] - KernelType = getattr(ems, f"{covariance}Kriging") + KernelType = getattr(gpr, f"{covariance}Kriging") kernel = KernelType() K = kernel(bvecs)