From 8835e63dc89f38b7d4165376c3b0a6db72712167 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Thu, 24 Oct 2024 20:48:35 -0400 Subject: [PATCH] STYLE: Rename the `EddyMotionGPR` instance to honor better its classname Rename the `EddyMotionGPR` instance to honor better its classname. --- scripts/dwi_gp_estimation_error_analysis.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py index 10f70464..2d3dd483 100644 --- a/scripts/dwi_gp_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -47,7 +47,7 @@ def cross_validate( X: np.ndarray, y: np.ndarray, cv: int, - gpm: EddyMotionGPR, + gpr: EddyMotionGPR, ) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]: """ Perform the experiment by estimating the dMRI signal using a Gaussian process model. @@ -60,7 +60,7 @@ def cross_validate( DWI signal. cv : :obj:`int` number of folds - gpm : obj:`~eddymotion.model._sklearn.EddyMotionGPR` + gpr : obj:`~eddymotion.model._sklearn.EddyMotionGPR` The eddymotion Gaussian process regressor object. Returns @@ -71,7 +71,7 @@ def cross_validate( """ rkf = RepeatedKFold(n_splits=cv, n_repeats=120 // cv) - scores = cross_val_score(gpm, X, y, scoring="neg_root_mean_squared_error", cv=rkf) + scores = cross_val_score(gpr, X, y, scoring="neg_root_mean_squared_error", cv=rkf) return scores @@ -176,7 +176,7 @@ def main() -> None: a = 1.15 lambda_s = 120 alpha = 100 - gpm = EddyMotionGPR( + gpr = EddyMotionGPR( kernel=SphericalKriging(a=a, lambda_s=lambda_s), alpha=alpha, optimizer=None, @@ -186,7 +186,7 @@ def main() -> None: scores = defaultdict(list, {}) for n in args.kfold: for i in range(args.repeats): - cv_scores = -1.0 * cross_validate(X, y.T, n, gpm) + cv_scores = -1.0 * cross_validate(X, y.T, n, gpr) scores["rmse"] += cv_scores.tolist() scores["repeat"] += [i] * len(cv_scores) scores["n_folds"] += [n] * len(cv_scores) @@ -202,7 +202,7 @@ def main() -> None: print(grouped[["rmse"]].std()) cv = KFold(n_splits=3, shuffle=False, random_state=None) - predictions = cross_val_predict(gpm, X, y.T, cv=cv) + predictions = cross_val_predict(gpr, X, y.T, cv=cv) testsims.serialize_dwi(predictions.T, args.dwi_pred_data_fname)