Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
STYLE: Rename the EddyMotionGPR instance to honor better its classname
Browse files Browse the repository at this point in the history
Rename the `EddyMotionGPR` instance to honor better its classname.
  • Loading branch information
jhlegarreta committed Oct 25, 2024
1 parent 4850266 commit 8835e63
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions scripts/dwi_gp_estimation_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def cross_validate(
X: np.ndarray,
y: np.ndarray,
cv: int,
gpm: EddyMotionGPR,
gpr: EddyMotionGPR,
) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]:
"""
Perform the experiment by estimating the dMRI signal using a Gaussian process model.
Expand All @@ -60,7 +60,7 @@ def cross_validate(
DWI signal.
cv : :obj:`int`
number of folds
gpm : obj:`~eddymotion.model._sklearn.EddyMotionGPR`
gpr : obj:`~eddymotion.model._sklearn.EddyMotionGPR`
The eddymotion Gaussian process regressor object.
Returns
Expand All @@ -71,7 +71,7 @@ def cross_validate(
"""

rkf = RepeatedKFold(n_splits=cv, n_repeats=120 // cv)
scores = cross_val_score(gpm, X, y, scoring="neg_root_mean_squared_error", cv=rkf)
scores = cross_val_score(gpr, X, y, scoring="neg_root_mean_squared_error", cv=rkf)
return scores


Expand Down Expand Up @@ -176,7 +176,7 @@ def main() -> None:
a = 1.15
lambda_s = 120
alpha = 100
gpm = EddyMotionGPR(
gpr = EddyMotionGPR(
kernel=SphericalKriging(a=a, lambda_s=lambda_s),
alpha=alpha,
optimizer=None,
Expand All @@ -186,7 +186,7 @@ def main() -> None:
scores = defaultdict(list, {})
for n in args.kfold:
for i in range(args.repeats):
cv_scores = -1.0 * cross_validate(X, y.T, n, gpm)
cv_scores = -1.0 * cross_validate(X, y.T, n, gpr)
scores["rmse"] += cv_scores.tolist()
scores["repeat"] += [i] * len(cv_scores)
scores["n_folds"] += [n] * len(cv_scores)
Expand All @@ -202,7 +202,7 @@ def main() -> None:
print(grouped[["rmse"]].std())

cv = KFold(n_splits=3, shuffle=False, random_state=None)
predictions = cross_val_predict(gpm, X, y.T, cv=cv)
predictions = cross_val_predict(gpr, X, y.T, cv=cv)
testsims.serialize_dwi(predictions.T, args.dwi_pred_data_fname)


Expand Down

0 comments on commit 8835e63

Please sign in to comment.