Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
fix: deep revision and refactor of the GP modeling
Browse files Browse the repository at this point in the history
I tried to migrate the cross-validation to only use Scikit-learn CV
utilities.
At first, Scikit-learn wanted our
``eddymotion.model._dipy.GaussianProcessModel`` to be a sklearn's
Estimator.
That made me realize that the boundaries between Scikit-learn, DIPY, and
eddymotion weren't clear.

This commit:

* Separates our old ``_dipy`` module into two modules, moving all
  Scikit-learn-related types into a new ``_sklearn`` module.
* Type-annotates and completes docstrings of most of the code.
* Updates the test script ``dwi_estimation_error_analysis.py`` to employ
  the new code.
  • Loading branch information
oesteban committed Oct 23, 2024
1 parent 90c8603 commit de3888f
Show file tree
Hide file tree
Showing 5 changed files with 519 additions and 500 deletions.
36 changes: 20 additions & 16 deletions scripts/dwi_estimation_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@
from dipy.core.sphere import HemiSphere, Sphere, disperse_charges
from dipy.sims.voxel import all_tensor_evecs, single_tensor
from matplotlib import pyplot as plt
from nireports.reportlets.modality.dwi import nii_to_carpetplot_data
from nireports.reportlets.nuisance import plot_carpet
from scipy.stats import pearsonr
from sklearn.metrics import root_mean_squared_error
from sklearn.model_selection import KFold, cross_val_score
from sklearn.model_selection import KFold, RepeatedKFold, cross_val_score

from eddymotion.model._dipy import GaussianProcessModel
from eddymotion.model._sklearn import (
EddyMotionGPR,
SphericalKriging,
)


def add_b0(bvals: np.ndarray, bvecs: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -300,7 +301,6 @@ def perform_experiment(
rng = np.random.default_rng(1234)

# Define the Gaussian process model parameter
kernel_model = "spherical"
lambda_s = 2.0
a = 1.0
sigma_sq = 0.5
Expand All @@ -326,9 +326,12 @@ def perform_experiment(
kf = KFold(n_splits=n, shuffle=False)

# Define the Gaussian process model instance
gp_model = GaussianProcessModel(
kernel_model=kernel_model, lambda_s=lambda_s, a=a, sigma_sq=sigma_sq
gp_model = EddyMotionGPR(
kernel=SphericalKriging(a=a, lambda_s=lambda_s),
alpha=sigma_sq,
optimizer=None,
)

_data = []
for _, (train_index, test_index) in enumerate(kf.split(nzero_bvecs)):
# Create the training mask leaving out the requested number of samples
Expand Down Expand Up @@ -382,20 +385,18 @@ def cross_validate(
"""

gp_params = {
"weighting": "exponential",
"lambda_s": 2.0,
"a": 1.0,
"sigma_sq": 2.0,
}

signal = single_tensor(gtab, S0=S0, evals=evals1, evecs=evecs, snr=snr)
gpm = GaussianProcessModel(**gp_params)
gpm = EddyMotionGPR(
kernel=SphericalKriging(a=2.15, lambda_s=120),
alpha=50,
optimizer=None,
)

X = gtab[~gtab.b0s_mask].bvecs
y = signal[~gtab.b0s_mask]

scores = cross_val_score(gpm, X, y, scoring="neg_root_mean_squared_error", cv=cv)
rkf = RepeatedKFold(n_splits=cv, n_repeats=120 // cv)
scores = cross_val_score(gpm, X, y, scoring="neg_root_mean_squared_error", cv=rkf)
return scores


Expand Down Expand Up @@ -486,6 +487,9 @@ def plot_error(


def plot_estimation_carpet(gt_nii, gp_nii, gtab, suptitle, **kwargs):
from nireports.reportlets.modality.dwi import nii_to_carpetplot_data
from nireports.reportlets.nuisance import plot_carpet

fig = plt.figure(layout="tight")
gs = gridspec.GridSpec(ncols=1, nrows=2, figure=fig)
fig.suptitle(suptitle)
Expand Down
Loading

0 comments on commit de3888f

Please sign in to comment.