diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py index a4b33d5..936931a 100644 --- a/scripts/dwi_gp_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -202,12 +202,14 @@ def main() -> None: # max_iter=2e5, ) + n_repeats = 10 + if args.kfold: # Use Scikit-learn cross validation scores = defaultdict(list, {}) for n in args.kfold: for i in range(args.repeats): - cv_scores = -1.0 * cross_validate(X, y.T, n, gpr) + cv_scores = -1.0 * cross_validate(X, y.T, n, n_repeats, gpr) scores["rmse"] += cv_scores.tolist() scores["repeat"] += [i] * len(cv_scores) scores["n_folds"] += [n] * len(cv_scores)