Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
ENH: Compute the number of CV repetitions dynamically
Browse files Browse the repository at this point in the history
Compute the number of CV repetitions dynamically by setting the number
to the maximum number of folds requested by the user.
  • Loading branch information
jhlegarreta committed Oct 26, 2024
1 parent 0fcc753 commit 0f3e048
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions scripts/dwi_gp_estimation_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def cross_validate(
X: np.ndarray,
y: np.ndarray,
cv: int,
n_repeats: int,
gpr: EddyMotionGPR,
) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]:
"""
Expand All @@ -60,6 +61,8 @@ def cross_validate(
DWI signal.
cv : :obj:`int`
number of folds
n_repeats : :obj:`int`
Number of times the cross-validator needs to be repeated.
gpr : obj:`~eddymotion.model._sklearn.EddyMotionGPR`
The eddymotion Gaussian process regressor object.
Expand All @@ -70,7 +73,7 @@ def cross_validate(
"""

rkf = RepeatedKFold(n_splits=cv, n_repeats=120 // cv)
rkf = RepeatedKFold(n_splits=cv, n_repeats=n_repeats)
scores = cross_val_score(gpr, X, y, scoring="neg_root_mean_squared_error", cv=rkf)
return scores

Expand Down Expand Up @@ -161,7 +164,7 @@ def main() -> None:
scores = defaultdict(list, {})
for n in args.kfold:
for i in range(args.repeats):
cv_scores = -1.0 * cross_validate(X, y.T, n, gpr)
cv_scores = -1.0 * cross_validate(X, y.T, n, np.max(args.kfold) // n, gpr)
scores["rmse"] += cv_scores.tolist()
scores["repeat"] += [i] * len(cv_scores)
scores["n_folds"] += [n] * len(cv_scores)
Expand Down

0 comments on commit 0f3e048

Please sign in to comment.