Skip to content

Commit

Permalink
BUG: Provide missing position arg to local cross validation function
Browse files Browse the repository at this point in the history
Provide missing position argument to local cross validation function in
GP estimation error analysis script.

Fixes
```
scripts/dwi_gp_estimation_error_analysis.py:210: error:
 Missing positional argument "gpr" in call to "cross_validate"  [call-arg]
scripts/dwi_gp_estimation_error_analysis.py:210: error:
 Unsupported operand types for * ("float" and "dict[int, list[tuple[ndarray[Any, Any], ndarray[Any, Any], ndarray[Any, Any], ndarray[Any, Any]]]]")  [operator]
scripts/dwi_gp_estimation_error_analysis.py:210: error:
 Argument 4 to "cross_validate" has incompatible type "DiffusionGPR"; expected "int"  [arg-type]
scripts/dwi_gp_estimation_error_analysis.py:211: error:
 "float" has no attribute "tolist"  [attr-defined]
scripts/dwi_gp_estimation_error_analysis.py:212: error:
 Argument 1 to "len" has incompatible type "float"; expected "Sized"  [arg-type]
scripts/dwi_gp_estimation_error_analysis.py:213: error:
 Argument 1 to "len" has incompatible type "float"; expected "Sized"  [arg-type]
scripts/dwi_gp_estimation_error_analysis.py:214: error:
 Argument 1 to "len" has incompatible type "float"; expected "Sized"  [arg-type]
scripts/dwi_gp_estimation_error_analysis.py:215: error:
 Argument 1 to "len" has incompatible type "float"; expected "Sized"  [arg-type]
```

raised for example in:
https://github.com/nipreps/nifreeze/actions/runs/12437972140/job/34728973936#step:8:113
  • Loading branch information
jhlegarreta committed Dec 22, 2024
1 parent 5c3e7d3 commit d9dafbe
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion scripts/dwi_gp_estimation_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,14 @@ def main() -> None:
# max_iter=2e5,
)

n_repeats = 10

if args.kfold:
# Use Scikit-learn cross validation
scores = defaultdict(list, {})
for n in args.kfold:
for i in range(args.repeats):
cv_scores = -1.0 * cross_validate(X, y.T, n, gpr)
cv_scores = -1.0 * cross_validate(X, y.T, n, n_repeats, gpr)
scores["rmse"] += cv_scores.tolist()
scores["repeat"] += [i] * len(cv_scores)
scores["n_folds"] += [n] * len(cv_scores)
Expand Down

0 comments on commit d9dafbe

Please sign in to comment.