From 8d167f938969c06b19db113ac078d997422eef5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Wed, 25 Sep 2024 20:44:30 -0400 Subject: [PATCH] ENH: Add a script to plot the signal estimated by the GP Add a script to plot the signal estimated by the GP as well as the error data generated by the error analysis script. Modify the signal visualization error plotting method to optionally accept the color the figure size parameters. Add methods to the the signal simulation module in order to serialize the dMRI data. Refactor the the the signal simulation module to: - Allow the dMRI signal generation method to generate evals randomly if not provided. - Allow reusing the polar random angle generation utility. - Allow the single tensor method to accept a random generator for the sake of reproducibility. - Set the `zip` function `strict` parameter to `True` as we want all iterables to have the same length. Modify the error analysis script to: - Reuse the `EddyMotionGPR` instance: factor it out from the CV function, as the instance does not change across folds and repeats. - Save the simulated signal and gtab. - Predict and save the signal of the GP estimation. - Save the simulated SNR to the CV scores data file. Since `None` indicates no noise, modify the `pandas` serialization method arguments so that `None` is not considered as a missing value. Take advantage of the commit to rename the `evals1` argument to `evals` in the error analysis script. --- scripts/dwi_estimation_error_analysis.py | 84 ++++++++---- scripts/dwi_estimation_plot.py | 160 +++++++++++++++++++++++ src/eddymotion/testing/simulations.py | 124 ++++++++++++++++-- src/eddymotion/viz/signals.py | 115 +++++++++++++++- 4 files changed, 445 insertions(+), 38 deletions(-) create mode 100644 scripts/dwi_estimation_plot.py diff --git a/scripts/dwi_estimation_error_analysis.py b/scripts/dwi_estimation_error_analysis.py index d6d739e1..10f70464 100644 --- a/scripts/dwi_estimation_error_analysis.py +++ b/scripts/dwi_estimation_error_analysis.py @@ -30,11 +30,11 @@ import argparse from collections import defaultdict +from pathlib import Path -# import nibabel as nib import numpy as np import pandas as pd -from sklearn.model_selection import RepeatedKFold, cross_val_score +from sklearn.model_selection import KFold, RepeatedKFold, cross_val_predict, cross_val_score from eddymotion.model._sklearn import ( EddyMotionGPR, @@ -47,24 +47,21 @@ def cross_validate( X: np.ndarray, y: np.ndarray, cv: int, + gpm: EddyMotionGPR, ) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]: """ Perform the experiment by estimating the dMRI signal using a Gaussian process model. Parameters ---------- - gtab : :obj:`~dipy.core.gradients.gradient_table` - Gradient table. - S0 : :obj:`float` - S0 value. - evals1 : :obj:`~numpy.ndarray` - Eigenvalues of the tensor. - evecs : :obj:`~numpy.ndarray` - Eigenvectors of the tensor. - snr : :obj:`float` - Signal-to-noise ratio. + X : :obj:`~numpy.ndarray` + Diffusion-encoding gradient vectors. + y : :obj:`~numpy.ndarray` + DWI signal. cv : :obj:`int` number of folds + gpm : obj:`~eddymotion.model._sklearn.EddyMotionGPR` + The eddymotion Gaussian process regressor object. Returns ------- @@ -72,11 +69,6 @@ def cross_validate( Data for the predicted signal and its error. """ - gpm = EddyMotionGPR( - kernel=SphericalKriging(a=1.15, lambda_s=120), - alpha=100, - optimizer=None, - ) rkf = RepeatedKFold(n_splits=cv, n_repeats=120 // cv) scores = cross_val_score(gpm, X, y, scoring="neg_root_mean_squared_error", cv=rkf) @@ -103,7 +95,32 @@ def _build_arg_parser() -> argparse.ArgumentParser: ) parser.add_argument("bval_shell", help="Shell b-value", type=float) parser.add_argument("S0", help="S0 value", type=float) - parser.add_argument("--evals1", help="Eigenvalues of the tensor", nargs="+", type=float) + parser.add_argument( + "error_data_fname", + help="Filename of TSV file containing the data to plot", + type=Path, + ) + parser.add_argument( + "dwi_gt_data_fname", + help="Filename of NIfTI file containing the generated DWI signal", + type=Path, + ) + parser.add_argument( + "bval_data_fname", + help="Filename of b-val file containing the diffusion-encoding gradient b-vals", + type=Path, + ) + parser.add_argument( + "bvec_data_fname", + help="Filename of b-vecs file containing the diffusion-encoding gradient b-vecs", + type=Path, + ) + parser.add_argument( + "dwi_pred_data_fname", + help="Filename of NIfTI file containing the predicted DWI signal", + type=Path, + ) + parser.add_argument("--evals", help="Eigenvalues of the tensor", nargs="+", type=float) parser.add_argument("--snr", help="Signal to noise ratio", type=float) parser.add_argument("--repeats", help="Number of repeats", type=int, default=5) parser.add_argument( @@ -134,37 +151,60 @@ def main() -> None: parser = _build_arg_parser() args = _parse_args(parser) + n_voxels = 100 + data, gtab = testsims.simulate_voxels( args.S0, - args.evals1, args.hsph_dirs, bval_shell=args.bval_shell, snr=args.snr, - n_voxels=100, + n_voxels=n_voxels, + evals=args.evals, seed=None, ) + # Save the generated signal and gradient table + testsims.serialize_dmri( + data, gtab, args.dwi_gt_data_fname, args.bval_data_fname, args.bvec_data_fname + ) + X = gtab[~gtab.b0s_mask].bvecs y = data[:, ~gtab.b0s_mask] + snr_str = args.snr if args.snr is not None else "None" + + a = 1.15 + lambda_s = 120 + alpha = 100 + gpm = EddyMotionGPR( + kernel=SphericalKriging(a=a, lambda_s=lambda_s), + alpha=alpha, + optimizer=None, + ) + # Use Scikit-learn cross validation scores = defaultdict(list, {}) for n in args.kfold: for i in range(args.repeats): - cv_scores = -1.0 * cross_validate(X, y.T, n) + cv_scores = -1.0 * cross_validate(X, y.T, n, gpm) scores["rmse"] += cv_scores.tolist() scores["repeat"] += [i] * len(cv_scores) scores["n_folds"] += [n] * len(cv_scores) + scores["snr"] += [snr_str] * len(cv_scores) print(f"Finished {n}-fold cross-validation") scores_df = pd.DataFrame(scores) - scores_df.to_csv("cv_scores.tsv", sep="\t", index=None, na_rep="n/a") + scores_df.to_csv(args.error_data_fname, sep="\t", index=None, na_rep="n/a") grouped = scores_df.groupby(["n_folds"]) print(grouped[["rmse"]].mean()) print(grouped[["rmse"]].std()) + cv = KFold(n_splits=3, shuffle=False, random_state=None) + predictions = cross_val_predict(gpm, X, y.T, cv=cv) + testsims.serialize_dwi(predictions.T, args.dwi_pred_data_fname) + if __name__ == "__main__": main() diff --git a/scripts/dwi_estimation_plot.py b/scripts/dwi_estimation_plot.py new file mode 100644 index 00000000..9c8400c9 --- /dev/null +++ b/scripts/dwi_estimation_plot.py @@ -0,0 +1,160 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# + +""" +Plot the RMSE (mean and std dev) and prediction surface from the predicted DWI +signal estimated using Gaussian processes k-fold cross-validation. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import nibabel as nib +import numpy as np +import pandas as pd +from dipy.core.gradients import gradient_table +from dipy.io import read_bvals_bvecs + +from eddymotion.viz.signals import plot_error, plot_prediction_surface + + +def _build_arg_parser() -> argparse.ArgumentParser: + """ + Build argument parser for command-line interface. + + Returns + ------- + :obj:`~argparse.ArgumentParser` + Argument parser for the script. + + """ + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawTextHelpFormatter + ) + parser.add_argument( + "error_data_fname", + help="Filename of TSV file containing the error data to plot", + type=Path, + ) + parser.add_argument( + "dwi_gt_data_fname", + help="Filename of NIfTI file containing the ground truth DWI signal", + type=Path, + ) + parser.add_argument( + "bval_data_fname", + help="Filename of b-val file containing the diffusion-encoding gradient b-vals", + type=Path, + ) + parser.add_argument( + "bvec_data_fname", + help="Filename of b-vecs file containing the diffusion-encoding gradient b-vecs", + type=Path, + ) + parser.add_argument( + "dwi_pred_data_fname", + help="Filename of NIfTI file containing the predicted DWI signal", + type=Path, + ) + parser.add_argument( + "error_plot_fname", + help="Filename of SVG file where the error plot will be saved", + type=Path, + ) + parser.add_argument( + "signal_surface_plot_fname", + help="Filename of SVG file where the predicted signal plot will be saved", + type=Path, + ) + return parser + + +def _parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace: + """ + Parse command-line arguments. + + Parameters + ---------- + parser : :obj:`~argparse.ArgumentParser` + Argument parser for the script. + + Returns + ------- + :obj:`~argparse.Namespace` + Parsed arguments. + """ + return parser.parse_args() + + +def main() -> None: + """Main function for running the experiment and plotting the results.""" + parser = _build_arg_parser() + args = _parse_args(parser) + + df = pd.read_csv(args.error_data_fname, sep="\t", keep_default_na=False, na_values="n/a") + + # Plot the prediction error + kfolds = sorted(np.unique(df["n_folds"].values)) + snr = np.unique(df["snr"].values).item() + rmse_data = [df.groupby("n_folds").get_group(k)["rmse"].values for k in kfolds] + axis = 1 + mean = np.mean(rmse_data, axis=axis) + std_dev = np.std(rmse_data, axis=axis) + xlabel = "k" + ylabel = "RMSE" + title = f"Gaussian process estimation\n(SNR={snr})" + fig = plot_error(kfolds, mean, std_dev, xlabel, ylabel, title) + fig.savefig(args.error_plot_fname) + plt.close(fig) + + # Plot the predicted DWI signal at a single voxel + + # Load the dMRI data + signal = nib.load(args.dwi_gt_data_fname).get_fdata() + y_pred = nib.load(args.dwi_pred_data_fname).get_fdata() + + bvals, bvecs = read_bvals_bvecs(str(args.bval_data_fname), str(args.bvec_data_fname)) + gtab = gradient_table(bvals, bvecs) + + # Pick one voxel randomly + rng = np.random.default_rng(1234) + idx = rng.integers(0, signal.shape[0], size=1).item() + + title = "GP model signal prediction" + fig, _, _ = plot_prediction_surface( + signal[idx, ~gtab.b0s_mask], + y_pred[idx], + signal[idx, gtab.b0s_mask].item(), + gtab[~gtab.b0s_mask].bvecs, + gtab[~gtab.b0s_mask].bvecs, + title, + "gray", + ) + fig.savefig(args.signal_surface_plot_fname, format="svg") + + +if __name__ == "__main__": + main() diff --git a/src/eddymotion/testing/simulations.py b/src/eddymotion/testing/simulations.py index 2c1aa11e..0db29a0d 100644 --- a/src/eddymotion/testing/simulations.py +++ b/src/eddymotion/testing/simulations.py @@ -24,13 +24,17 @@ from __future__ import annotations -# import nibabel as nib +import nibabel as nib import numpy as np from dipy.core.geometry import sphere2cart from dipy.core.gradients import gradient_table from dipy.core.sphere import HemiSphere, Sphere, disperse_charges from dipy.sims.voxel import all_tensor_evecs, single_tensor +# Set according to Canales-Rodriguez, NIMG 184 2019, https://doi.org/10.1016/j.neuroimage.2018.08.071 +BOUNDS_LAMBDA1: tuple[float, float] = (1.4e-3, 1.8e-3) +BOUNDS_LAMBDA23: tuple[float, float] = (0.1e-3, 0.5e-3) + def add_b0(bvals: np.ndarray, bvecs: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """ @@ -197,31 +201,129 @@ def get_query_vectors( return gtab.bvecs[idx], np.where(idx)[0] -def single_fiber_voxel(gtab, S0, evals, theta=0, phi=0, snr=20): +def single_fiber_voxel(gtab, S0, evals, rng, theta=0, phi=0, snr=20): # create eigenvectors for a single fiber evecs = create_single_fiber_evecs(theta=theta, phi=phi) # Generate some data - return single_tensor(gtab, S0=S0, evals=evals, evecs=evecs, snr=snr) + return single_tensor(gtab, S0=S0, evals=evals, evecs=evecs, snr=snr, rng=rng) + + +def create_random_polar_angles(size, rng): + """Create polar angles drawn from a uniform distribution.""" + + return zip( + rng.uniform(0, np.pi, size=size), + rng.uniform(0, 2.0 * np.pi, size=size), + strict=True, + ) + + +def create_random_evals(size, rng): + """Create eigenvalus drawn from a uniform distribution.""" + + # lambda_2 = lambda_3 set according to Canales-Rodriguez, NIMG 184 2019, + # https://doi.org/10.1016/j.neuroimage.2018.08.071 + return zip( + rng.uniform(*BOUNDS_LAMBDA1, size=size), + *[rng.uniform(*BOUNDS_LAMBDA23, size=size)] * 2, + strict=True, + ) + +def group_values(values, group_size): + return np.asarray([values[i : i + group_size] for i in range(0, len(values), group_size)]) -def simulate_voxels(S0, evals, hsph_dirs, bval_shell=1000, snr=20, n_voxels=1, seed=None): + +def simulate_voxels(S0, hsph_dirs, bval_shell=1000, snr=20, n_voxels=1, evals=None, seed=None): # Create a gradient table for a single-shell gtab = create_single_shell_gradient_table(hsph_dirs, bval_shell) rng = np.random.default_rng(seed) - angles = zip( - rng.uniform(0, np.pi, size=n_voxels), - rng.uniform(0, 2.0 * np.pi, size=n_voxels), - strict=False, - ) + angles = create_random_polar_angles(n_voxels, rng) + if evals is None: + _evals = create_random_evals(n_voxels, rng) + else: + _evals = group_values(evals, 3) + if _evals.shape[0] == 1 and n_voxels != 1: + _evals = np.repeat(_evals, n_voxels, axis=0) signal = np.vstack( [ - single_fiber_voxel(gtab, S0, evals, theta=theta, phi=phi, snr=snr) - for theta, phi in angles + single_fiber_voxel(gtab, S0, _eignvls, rng, theta=theta, phi=phi, snr=snr) + for (theta, phi), _eignvls in zip(angles, _evals, strict=True) ] ) return signal, gtab + + +def serialize_dwi(dwi_data, dwi_data_fname, affine: np.ndarray | None = None): + """Serialize DWI data. + + Parameters + ---------- + dwi_data : :obj:`~numpy.ndarray` + DWI data. + dwi_data_fname : :obj:`str` + Filename of NIfTI file to save the DWI signal. + affine : :obj:`~numpy.ndarray`, optional + Affine matrix. If ``None`` an identity affine matrix is used. + """ + + if affine is None: + affine = np.eye(4) + + dwi_img = nib.Nifti1Image(dwi_data, affine=affine) + nib.save(dwi_img, dwi_data_fname) + + +def serialize_gtab(gtab, bval_data_fname, bvec_data_fname): + """Serialize dMRI gradient-encoding table data into a pair of b-vals and + b-vecs files. + + Parameters + ---------- + gtab : :obj:`~dipy.core.gradients.gradient_table` + Gradient table. + bval_data_fname : :obj:`str` + Filename of NIfTI file to save the diffusion-encoding gradient b-vals. + bvec_data_fname : :obj:`str` + Filename of NIfTI file to save the diffusion-encoding gradient b-vecs. + """ + + fmt = "%d" + np.savetxt(bval_data_fname, gtab.bvals, newline=" ", fmt=fmt) + fmt = "%.3f" + np.savetxt(bvec_data_fname, gtab.bvecs.T, fmt=fmt) + + +def serialize_dmri( + dwi_data, + gtab, + dwi_data_fname, + bval_data_fname, + bvec_data_fname, + affine: np.ndarray | None = None, +): + """Serialize dMRI data. + + Parameters + ---------- + dwi_data : :obj:`~numpy.ndarray` + DWI data. + gtab : :obj:`~dipy.core.gradients.gradient_table` + Gradient table. + dwi_data_fname : :obj:`str` + Filename of NIfTI file to save the DWI signal. + bval_data_fname : :obj:`str` + Filename of NIfTI file to save the diffusion-encoding gradient b-vals. + bvec_data_fname : :obj:`str` + Filename of NIfTI file to save the diffusion-encoding gradient b-vecs. + affine : :obj:`~numpy.ndarray`, optional + Affine matrix. If ``None`` an identity affine matrix is used. + """ + + serialize_dwi(dwi_data, dwi_data_fname, affine=affine) + serialize_gtab(gtab, bval_data_fname, bvec_data_fname) diff --git a/src/eddymotion/viz/signals.py b/src/eddymotion/viz/signals.py index 60c3933c..c763f1f2 100644 --- a/src/eddymotion/viz/signals.py +++ b/src/eddymotion/viz/signals.py @@ -25,11 +25,19 @@ import matplotlib.gridspec as gridspec import numpy as np from matplotlib import pyplot as plt +from scipy.spatial import ConvexHull, KDTree from scipy.stats import pearsonr def plot_error( - kfolds: list[int], mean: np.ndarray, std_dev: np.ndarray, xlabel: str, ylabel: str, title: str + kfolds: list[int], + mean: np.ndarray, + std_dev: np.ndarray, + xlabel: str, + ylabel: str, + title: str, + color: str = "orange", + figsize: tuple[int, int] = (19.2, 10.8), ) -> plt.Figure: """ Plot the error and standard deviation. @@ -48,6 +56,10 @@ def plot_error( Y-axis label. title : :obj:`str` Plot title. + color : :obj:`str`, optional + Plot color. + figsize : :obj:`tuple`, optional + Figure size. Returns ------- @@ -55,10 +67,10 @@ def plot_error( Matplotlib figure object. """ - fig, ax = plt.subplots() - ax.plot(kfolds, mean, c="orange") - ax.fill_between(kfolds, mean - std_dev, mean + std_dev, alpha=0.5, color="orange") - ax.scatter(kfolds, mean, c="orange") + fig, ax = plt.subplots(figsize=figsize) + ax.plot(kfolds, mean, c=color) + ax.fill_between(kfolds, mean - std_dev, mean + std_dev, alpha=0.5, color=color) + ax.scatter(kfolds, mean, c=color) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_xticks(kfolds) @@ -112,3 +124,96 @@ def plot_correlation(x, y, title): fig.tight_layout() return fig, r + + +def calculate_sphere_pts(points, center): + """Calculate the location of each point when it is expanded out to the sphere.""" + + kdtree = KDTree(points) # tree of nearest points + # d is an array of distances, i is an array of indices + d, i = kdtree.query(center, points.shape[0]) + sphere_pts = np.zeros(points.shape, dtype=float) + + radius = np.amax(d) + for p in range(points.shape[0]): + sphere_pts[p] = points[i[p]] * radius / d[p] + # points and the indices for where they were in the original lists + return sphere_pts, i + + +def compute_dmri_convex_hull(s, dirs, mask=None): + """Compute the convex hull of the dMRI signal s.""" + + if mask is None: + mask = np.ones(len(dirs), dtype=bool) + + # Scale the original sampling directions by the corresponding signal values + scaled_bvecs = dirs[mask] * np.asarray(s)[:, np.newaxis] + + # Create the data for the convex hull: project the scaled vectors to a + # sphere + sphere_pts, sphere_idx = calculate_sphere_pts(scaled_bvecs, [0, 0, 0]) + + # Create the convex hull: find the right ordering of vertices for the + # triangles: ConvexHull finds the simplices of the points on the outside of + # the data set + hull = ConvexHull(sphere_pts) + triang_idx = hull.simplices # returns the list of indices for each triangle + + return scaled_bvecs, sphere_idx, triang_idx + + +def plot_surface(scaled_vecs, sphere_idx, triang_idx, title, cmap): + """Plot a surface.""" + + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + + ax.scatter3D( + scaled_vecs[:, 0], scaled_vecs[:, 1], scaled_vecs[:, 2], s=2, c="black", alpha=1.0 + ) + + surface = ax.plot_trisurf( + scaled_vecs[sphere_idx, 0], + scaled_vecs[sphere_idx, 1], + scaled_vecs[sphere_idx, 2], + triangles=triang_idx, + cmap=cmap, + alpha=0.6, + ) + + ax.view_init(10, 45) + ax.set_aspect("equal", adjustable="box") + ax.set_title(title) + + return fig, ax, surface + + +def plot_signal_data(y, ax): + """Plot the data provided as a scatter plot""" + + ax.scatter( + y[:, 0], y[:, 1], y[:, 2], color="red", marker="*", alpha=0.8, s=5, label="Original points" + ) + + +def plot_prediction_surface(y, y_pred, S0, y_dirs, y_pred_dirs, title, cmap): + """Plot the prediction surface obtained by computing the convex hull of the + predicted signal data, and plot the true data as a scatter plot.""" + + # Scale the original sampling directions by the corresponding signal values + y_bvecs = y_dirs * np.asarray(y)[:, np.newaxis] + + # Compute the convex hull + y_pred_bvecs, sphere_idx, triang_idx = compute_dmri_convex_hull(y_pred, y_pred_dirs) + + # Plot the surface + fig, ax, surface = plot_surface(y_pred_bvecs, sphere_idx, triang_idx, title, cmap) + + # Add the underlying signal to the plot + # plot_signal_data(y_bvecs/S0, ax) + plot_signal_data(y_bvecs, ax) + + fig.tight_layout() + + return fig, ax, surface