From d9dafbe197530acff4aae9ae0cf98a3bc6232ac1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Sat, 21 Dec 2024 20:36:34 -0500 Subject: [PATCH] BUG: Provide missing position arg to local cross validation function Provide missing position argument to local cross validation function in GP estimation error analysis script. Fixes ``` scripts/dwi_gp_estimation_error_analysis.py:210: error: Missing positional argument "gpr" in call to "cross_validate" [call-arg] scripts/dwi_gp_estimation_error_analysis.py:210: error: Unsupported operand types for * ("float" and "dict[int, list[tuple[ndarray[Any, Any], ndarray[Any, Any], ndarray[Any, Any], ndarray[Any, Any]]]]") [operator] scripts/dwi_gp_estimation_error_analysis.py:210: error: Argument 4 to "cross_validate" has incompatible type "DiffusionGPR"; expected "int" [arg-type] scripts/dwi_gp_estimation_error_analysis.py:211: error: "float" has no attribute "tolist" [attr-defined] scripts/dwi_gp_estimation_error_analysis.py:212: error: Argument 1 to "len" has incompatible type "float"; expected "Sized" [arg-type] scripts/dwi_gp_estimation_error_analysis.py:213: error: Argument 1 to "len" has incompatible type "float"; expected "Sized" [arg-type] scripts/dwi_gp_estimation_error_analysis.py:214: error: Argument 1 to "len" has incompatible type "float"; expected "Sized" [arg-type] scripts/dwi_gp_estimation_error_analysis.py:215: error: Argument 1 to "len" has incompatible type "float"; expected "Sized" [arg-type] ``` raised for example in: https://github.com/nipreps/nifreeze/actions/runs/12437972140/job/34728973936#step:8:113 --- scripts/dwi_gp_estimation_error_analysis.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py index a4b33d5..936931a 100644 --- a/scripts/dwi_gp_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -202,12 +202,14 @@ def main() -> None: # max_iter=2e5, ) + n_repeats = 10 + if args.kfold: # Use Scikit-learn cross validation scores = defaultdict(list, {}) for n in args.kfold: for i in range(args.repeats): - cv_scores = -1.0 * cross_validate(X, y.T, n, gpr) + cv_scores = -1.0 * cross_validate(X, y.T, n, n_repeats, gpr) scores["rmse"] += cv_scores.tolist() scores["repeat"] += [i] * len(cv_scores) scores["n_folds"] += [n] * len(cv_scores)