From 0fcc75328577d16fa5f8fcec2d7c1fd4795cc663 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Sat, 26 Oct 2024 13:09:52 -0400 Subject: [PATCH 1/3] ENH: Miscellaneous improvements to the scripts Miscellaneous improvements to the scripts: - Make the b-value parameter be an integer. - Remove storing the generated dMRI signal/gtab and predictions from the error analysis plot and constrain the script to the CV error analysis. - Make the error analysis script main method description honor its new purpose. - Store the generated signal b-value in the error analysis script so that it can be added to the plot title. - Remove plotting the predicted signal from the error analysis plot script. - Make the error analysis plot script main method description honor its purpose. - Print the generated signal b-value in the error analysis plot title. - Add a script to generate a synthetic signal and predict using a GP. Store the generated dMRI signal/gtab and the predictions. - Add a script to plot the generated synthetic signal and the GP predictions. - Rename the GP error analysis plot script so that it honors better its scope now. --- scripts/dwi_gp_estimation_error_analysis.py | 36 +--- .../dwi_gp_estimation_error_analysis_plot.py | 108 ++++++++++++ ...ot.py => dwi_gp_estimation_signal_plot.py} | 41 ++--- scripts/dwi_gp_estimation_simulated_signal.py | 164 ++++++++++++++++++ 4 files changed, 286 insertions(+), 63 deletions(-) create mode 100644 scripts/dwi_gp_estimation_error_analysis_plot.py rename scripts/{dwi_gp_estimation_analysis_plot.py => dwi_gp_estimation_signal_plot.py} (73%) create mode 100644 scripts/dwi_gp_estimation_simulated_signal.py diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py index 2d3dd483..cca074c4 100644 --- a/scripts/dwi_gp_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -34,7 +34,7 @@ import numpy as np import pandas as pd -from sklearn.model_selection import KFold, RepeatedKFold, cross_val_predict, cross_val_score +from sklearn.model_selection import RepeatedKFold, cross_val_score from eddymotion.model._sklearn import ( EddyMotionGPR, @@ -93,33 +93,13 @@ def _build_arg_parser() -> argparse.ArgumentParser: help="Number of diffusion gradient-encoding directions in the half sphere", type=int, ) - parser.add_argument("bval_shell", help="Shell b-value", type=float) + parser.add_argument("bval_shell", help="Shell b-value", type=int) parser.add_argument("S0", help="S0 value", type=float) parser.add_argument( "error_data_fname", help="Filename of TSV file containing the data to plot", type=Path, ) - parser.add_argument( - "dwi_gt_data_fname", - help="Filename of NIfTI file containing the generated DWI signal", - type=Path, - ) - parser.add_argument( - "bval_data_fname", - help="Filename of b-val file containing the diffusion-encoding gradient b-vals", - type=Path, - ) - parser.add_argument( - "bvec_data_fname", - help="Filename of b-vecs file containing the diffusion-encoding gradient b-vecs", - type=Path, - ) - parser.add_argument( - "dwi_pred_data_fname", - help="Filename of NIfTI file containing the predicted DWI signal", - type=Path, - ) parser.add_argument("--evals", help="Eigenvalues of the tensor", nargs="+", type=float) parser.add_argument("--snr", help="Signal to noise ratio", type=float) parser.add_argument("--repeats", help="Number of repeats", type=int, default=5) @@ -147,7 +127,7 @@ def _parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace: def main() -> None: - """Main function for running the experiment and plotting the results.""" + """Main function for running the experiment.""" parser = _build_arg_parser() args = _parse_args(parser) @@ -163,11 +143,6 @@ def main() -> None: seed=None, ) - # Save the generated signal and gradient table - testsims.serialize_dmri( - data, gtab, args.dwi_gt_data_fname, args.bval_data_fname, args.bvec_data_fname - ) - X = gtab[~gtab.b0s_mask].bvecs y = data[:, ~gtab.b0s_mask] @@ -190,6 +165,7 @@ def main() -> None: scores["rmse"] += cv_scores.tolist() scores["repeat"] += [i] * len(cv_scores) scores["n_folds"] += [n] * len(cv_scores) + scores["bval"] += [args.bval_shell] * len(cv_scores) scores["snr"] += [snr_str] * len(cv_scores) print(f"Finished {n}-fold cross-validation") @@ -201,10 +177,6 @@ def main() -> None: print(grouped[["rmse"]].mean()) print(grouped[["rmse"]].std()) - cv = KFold(n_splits=3, shuffle=False, random_state=None) - predictions = cross_val_predict(gpr, X, y.T, cv=cv) - testsims.serialize_dwi(predictions.T, args.dwi_pred_data_fname) - if __name__ == "__main__": main() diff --git a/scripts/dwi_gp_estimation_error_analysis_plot.py b/scripts/dwi_gp_estimation_error_analysis_plot.py new file mode 100644 index 00000000..176e56f5 --- /dev/null +++ b/scripts/dwi_gp_estimation_error_analysis_plot.py @@ -0,0 +1,108 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# + +""" +Plot the RMSE (mean and std dev) from the predicted DWI signal estimated using +Gaussian processes k-fold cross-validation. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from eddymotion.viz.signals import plot_error + + +def _build_arg_parser() -> argparse.ArgumentParser: + """ + Build argument parser for command-line interface. + + Returns + ------- + :obj:`~argparse.ArgumentParser` + Argument parser for the script. + + """ + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument( + "error_data_fname", + help="Filename of TSV file containing the error data to plot", + type=Path, + ) + parser.add_argument( + "error_plot_fname", + help="Filename of SVG file where the error plot will be saved", + type=Path, + ) + return parser + + +def _parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace: + """ + Parse command-line arguments. + + Parameters + ---------- + parser : :obj:`~argparse.ArgumentParser` + Argument parser for the script. + + Returns + ------- + :obj:`~argparse.Namespace` + Parsed arguments. + """ + return parser.parse_args() + + +def main() -> None: + """Main function to plot the Gaussian Process estimation error analysis data.""" + parser = _build_arg_parser() + args = _parse_args(parser) + + df = pd.read_csv(args.error_data_fname, sep="\t", keep_default_na=False, na_values="n/a") + + # Plot the prediction error + kfolds = sorted(np.unique(df["n_folds"].values)) + snr = np.unique(df["snr"].values).item() + bval = np.unique(df["bval"].values).item() + rmse_data = [df.groupby("n_folds").get_group(k)["rmse"].values for k in kfolds] + axis = 1 + mean = np.mean(rmse_data, axis=axis) + std_dev = np.std(rmse_data, axis=axis) + xlabel = "k" + ylabel = "RMSE" + title = f"Gaussian process estimation\n(b={bval} s/mm^2; SNR={snr})" + fig = plot_error(kfolds, mean, std_dev, xlabel, ylabel, title) + fig.savefig(args.error_plot_fname) + plt.close(fig) + + +if __name__ == "__main__": + main() diff --git a/scripts/dwi_gp_estimation_analysis_plot.py b/scripts/dwi_gp_estimation_signal_plot.py similarity index 73% rename from scripts/dwi_gp_estimation_analysis_plot.py rename to scripts/dwi_gp_estimation_signal_plot.py index 9c8400c9..9148947e 100644 --- a/scripts/dwi_gp_estimation_analysis_plot.py +++ b/scripts/dwi_gp_estimation_signal_plot.py @@ -22,8 +22,7 @@ # """ -Plot the RMSE (mean and std dev) and prediction surface from the predicted DWI -signal estimated using Gaussian processes k-fold cross-validation. +Plot the predicted DWI signal estimated using Gaussian processes. """ from __future__ import annotations @@ -31,14 +30,12 @@ import argparse from pathlib import Path -import matplotlib.pyplot as plt import nibabel as nib import numpy as np -import pandas as pd from dipy.core.gradients import gradient_table from dipy.io import read_bvals_bvecs -from eddymotion.viz.signals import plot_error, plot_prediction_surface +from eddymotion.viz.signals import plot_prediction_surface def _build_arg_parser() -> argparse.ArgumentParser: @@ -54,11 +51,6 @@ def _build_arg_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawTextHelpFormatter ) - parser.add_argument( - "error_data_fname", - help="Filename of TSV file containing the error data to plot", - type=Path, - ) parser.add_argument( "dwi_gt_data_fname", help="Filename of NIfTI file containing the ground truth DWI signal", @@ -66,12 +58,12 @@ def _build_arg_parser() -> argparse.ArgumentParser: ) parser.add_argument( "bval_data_fname", - help="Filename of b-val file containing the diffusion-encoding gradient b-vals", + help="Filename of b-val file containing the diffusion-weighting values", type=Path, ) parser.add_argument( "bvec_data_fname", - help="Filename of b-vecs file containing the diffusion-encoding gradient b-vecs", + help="Filename of b-vecs file containing the diffusion-encoding gradient directions", type=Path, ) parser.add_argument( @@ -80,8 +72,9 @@ def _build_arg_parser() -> argparse.ArgumentParser: type=Path, ) parser.add_argument( - "error_plot_fname", - help="Filename of SVG file where the error plot will be saved", + "bvec_pred_data_fname", + help="Filename of b-vecs file containing the diffusion-encoding gradient b-vecs where " + "the prediction is done", type=Path, ) parser.add_argument( @@ -114,22 +107,6 @@ def main() -> None: parser = _build_arg_parser() args = _parse_args(parser) - df = pd.read_csv(args.error_data_fname, sep="\t", keep_default_na=False, na_values="n/a") - - # Plot the prediction error - kfolds = sorted(np.unique(df["n_folds"].values)) - snr = np.unique(df["snr"].values).item() - rmse_data = [df.groupby("n_folds").get_group(k)["rmse"].values for k in kfolds] - axis = 1 - mean = np.mean(rmse_data, axis=axis) - std_dev = np.std(rmse_data, axis=axis) - xlabel = "k" - ylabel = "RMSE" - title = f"Gaussian process estimation\n(SNR={snr})" - fig = plot_error(kfolds, mean, std_dev, xlabel, ylabel, title) - fig.savefig(args.error_plot_fname) - plt.close(fig) - # Plot the predicted DWI signal at a single voxel # Load the dMRI data @@ -143,13 +120,15 @@ def main() -> None: rng = np.random.default_rng(1234) idx = rng.integers(0, signal.shape[0], size=1).item() + dirs = np.loadtxt(args.bvec_pred_data_fname) + title = "GP model signal prediction" fig, _, _ = plot_prediction_surface( signal[idx, ~gtab.b0s_mask], y_pred[idx], signal[idx, gtab.b0s_mask].item(), gtab[~gtab.b0s_mask].bvecs, - gtab[~gtab.b0s_mask].bvecs, + dirs.T, title, "gray", ) diff --git a/scripts/dwi_gp_estimation_simulated_signal.py b/scripts/dwi_gp_estimation_simulated_signal.py new file mode 100644 index 00000000..3b534c68 --- /dev/null +++ b/scripts/dwi_gp_estimation_simulated_signal.py @@ -0,0 +1,164 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# + +""" +Generate a synthetic dMRI signal and estimate values using Gaussian processes. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import numpy as np +from dipy.core.sphere import Sphere + +from eddymotion.model._sklearn import EddyMotionGPR, SphericalKriging +from eddymotion.testing import simulations as testsims + +SAMPLING_DIRECTIONS = 200 + + +def _build_arg_parser() -> argparse.ArgumentParser: + """ + Build argument parser for command-line interface. + + Returns + ------- + :obj:`~argparse.ArgumentParser` + Argument parser for the script. + + """ + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument( + "hsph_dirs", + help="Number of diffusion gradient-encoding directions in the half sphere", + type=int, + ) + parser.add_argument("bval_shell", help="Shell b-value", type=int) + parser.add_argument("S0", help="S0 value", type=float) + parser.add_argument("--evals", help="Eigenvalues of the tensor", nargs="+", type=float) + parser.add_argument("--snr", help="Signal to noise ratio", type=float) + parser.add_argument( + "dwi_gt_data_fname", + help="Filename of NIfTI file containing the ground truth DWI signal", + type=Path, + ) + parser.add_argument( + "bval_data_fname", + help="Filename of b-val file containing the diffusion-weighting values", + type=Path, + ) + parser.add_argument( + "bvec_data_fname", + help="Filename of b-vecs file containing the diffusion-encoding gradient directions", + type=Path, + ) + parser.add_argument( + "dwi_pred_data_fname", + help="Filename of NIfTI file containing the predicted DWI signal", + type=Path, + ) + parser.add_argument( + "bvec_pred_data_fname", + help="Filename of b-vecs file containing the diffusion-encoding gradient b-vecs where " + "the prediction is done", + type=Path, + ) + return parser + + +def _parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace: + """ + Parse command-line arguments. + + Parameters + ---------- + parser : :obj:`~argparse.ArgumentParser` + Argument parser for the script. + + Returns + ------- + :obj:`~argparse.Namespace` + Parsed arguments. + """ + return parser.parse_args() + + +def main() -> None: + """Main function for running the experiment.""" + parser = _build_arg_parser() + args = _parse_args(parser) + + seed = 1234 + n_voxels = 100 + + data, gtab = testsims.simulate_voxels( + args.S0, + args.hsph_dirs, + bval_shell=args.bval_shell, + snr=args.snr, + n_voxels=n_voxels, + evals=args.evals, + seed=seed, + ) + + # Save the generated signal and gradient table + testsims.serialize_dmri( + data, gtab, args.dwi_gt_data_fname, args.bval_data_fname, args.bvec_data_fname + ) + + # Fit the Gaussian Process regressor and predict on an arbitrary number of + # directions + a = 1.15 + lambda_s = 120 + alpha = 100 + gpr = EddyMotionGPR( + kernel=SphericalKriging(a=a, lambda_s=lambda_s), + alpha=alpha, + optimizer=None, + ) + + # Use all available data to train the GP + X_train = gtab[~gtab.b0s_mask].bvecs + y = data[:, ~gtab.b0s_mask] + + gpr_fit = gpr.fit(X_train, y.T) + + # Predict on the testing data, plus a series of random directions + theta, phi = testsims.create_random_polar_coordinates(SAMPLING_DIRECTIONS, seed=seed) + sph = Sphere(theta=theta, phi=phi) + + X_test = np.vstack([gtab[~gtab.b0s_mask].bvecs, sph.vertices]) + + predictions = gpr_fit.predict(X_test) + + # Save the predicted data + testsims.serialize_dwi(predictions.T, args.dwi_pred_data_fname) + np.savetxt(args.bvec_pred_data_fname, X_test.T, fmt="%.3f") + + +if __name__ == "__main__": + main() From 0f3e0489f4c1fb363ca1c85bb7f2564455bad5dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Sat, 26 Oct 2024 13:42:41 -0400 Subject: [PATCH 2/3] ENH: Compute the number of CV repetitions dynamically Compute the number of CV repetitions dynamically by setting the number to the maximum number of folds requested by the user. --- scripts/dwi_gp_estimation_error_analysis.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py index cca074c4..289a29fb 100644 --- a/scripts/dwi_gp_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -47,6 +47,7 @@ def cross_validate( X: np.ndarray, y: np.ndarray, cv: int, + n_repeats: int, gpr: EddyMotionGPR, ) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]: """ @@ -60,6 +61,8 @@ def cross_validate( DWI signal. cv : :obj:`int` number of folds + n_repeats : :obj:`int` + Number of times the cross-validator needs to be repeated. gpr : obj:`~eddymotion.model._sklearn.EddyMotionGPR` The eddymotion Gaussian process regressor object. @@ -70,7 +73,7 @@ def cross_validate( """ - rkf = RepeatedKFold(n_splits=cv, n_repeats=120 // cv) + rkf = RepeatedKFold(n_splits=cv, n_repeats=n_repeats) scores = cross_val_score(gpr, X, y, scoring="neg_root_mean_squared_error", cv=rkf) return scores @@ -161,7 +164,7 @@ def main() -> None: scores = defaultdict(list, {}) for n in args.kfold: for i in range(args.repeats): - cv_scores = -1.0 * cross_validate(X, y.T, n, gpr) + cv_scores = -1.0 * cross_validate(X, y.T, n, np.max(args.kfold) // n, gpr) scores["rmse"] += cv_scores.tolist() scores["repeat"] += [i] * len(cv_scores) scores["n_folds"] += [n] * len(cv_scores) From aea2df5ddf9691b501c3ac7bf21ce601c12da660 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Sat, 26 Oct 2024 13:45:16 -0400 Subject: [PATCH 3/3] DOC: Improve consistency in script function docstring Improve consistency in script function docstring: use sentence style (start with capitals and end with period). --- scripts/dwi_gp_estimation_error_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/dwi_gp_estimation_error_analysis.py b/scripts/dwi_gp_estimation_error_analysis.py index 289a29fb..d910d63f 100644 --- a/scripts/dwi_gp_estimation_error_analysis.py +++ b/scripts/dwi_gp_estimation_error_analysis.py @@ -60,7 +60,7 @@ def cross_validate( y : :obj:`~numpy.ndarray` DWI signal. cv : :obj:`int` - number of folds + Number of folds. n_repeats : :obj:`int` Number of times the cross-validator needs to be repeated. gpr : obj:`~eddymotion.model._sklearn.EddyMotionGPR`