Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
Merge pull request #248 from jhlegarreta/RefactorScripts
Browse files Browse the repository at this point in the history
ENH: Miscellaneous improvements to the scripts
  • Loading branch information
jhlegarreta authored Oct 29, 2024
2 parents ef91d4f + aea2df5 commit 26ef287
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 66 deletions.
45 changes: 10 additions & 35 deletions scripts/dwi_gp_estimation_error_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold, RepeatedKFold, cross_val_predict, cross_val_score
from sklearn.model_selection import RepeatedKFold, cross_val_score

from eddymotion.model._sklearn import (
EddyMotionGPR,
Expand All @@ -47,6 +47,7 @@ def cross_validate(
X: np.ndarray,
y: np.ndarray,
cv: int,
n_repeats: int,
gpr: EddyMotionGPR,
) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]:
"""
Expand All @@ -59,7 +60,9 @@ def cross_validate(
y : :obj:`~numpy.ndarray`
DWI signal.
cv : :obj:`int`
number of folds
Number of folds.
n_repeats : :obj:`int`
Number of times the cross-validator needs to be repeated.
gpr : obj:`~eddymotion.model._sklearn.EddyMotionGPR`
The eddymotion Gaussian process regressor object.
Expand All @@ -70,7 +73,7 @@ def cross_validate(
"""

rkf = RepeatedKFold(n_splits=cv, n_repeats=120 // cv)
rkf = RepeatedKFold(n_splits=cv, n_repeats=n_repeats)
scores = cross_val_score(gpr, X, y, scoring="neg_root_mean_squared_error", cv=rkf)
return scores

Expand All @@ -93,33 +96,13 @@ def _build_arg_parser() -> argparse.ArgumentParser:
help="Number of diffusion gradient-encoding directions in the half sphere",
type=int,
)
parser.add_argument("bval_shell", help="Shell b-value", type=float)
parser.add_argument("bval_shell", help="Shell b-value", type=int)
parser.add_argument("S0", help="S0 value", type=float)
parser.add_argument(
"error_data_fname",
help="Filename of TSV file containing the data to plot",
type=Path,
)
parser.add_argument(
"dwi_gt_data_fname",
help="Filename of NIfTI file containing the generated DWI signal",
type=Path,
)
parser.add_argument(
"bval_data_fname",
help="Filename of b-val file containing the diffusion-encoding gradient b-vals",
type=Path,
)
parser.add_argument(
"bvec_data_fname",
help="Filename of b-vecs file containing the diffusion-encoding gradient b-vecs",
type=Path,
)
parser.add_argument(
"dwi_pred_data_fname",
help="Filename of NIfTI file containing the predicted DWI signal",
type=Path,
)
parser.add_argument("--evals", help="Eigenvalues of the tensor", nargs="+", type=float)
parser.add_argument("--snr", help="Signal to noise ratio", type=float)
parser.add_argument("--repeats", help="Number of repeats", type=int, default=5)
Expand Down Expand Up @@ -147,7 +130,7 @@ def _parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace:


def main() -> None:
"""Main function for running the experiment and plotting the results."""
"""Main function for running the experiment."""
parser = _build_arg_parser()
args = _parse_args(parser)

Expand All @@ -163,11 +146,6 @@ def main() -> None:
seed=None,
)

# Save the generated signal and gradient table
testsims.serialize_dmri(
data, gtab, args.dwi_gt_data_fname, args.bval_data_fname, args.bvec_data_fname
)

X = gtab[~gtab.b0s_mask].bvecs
y = data[:, ~gtab.b0s_mask]

Expand All @@ -190,10 +168,11 @@ def main() -> None:
scores = defaultdict(list, {})
for n in args.kfold:
for i in range(args.repeats):
cv_scores = -1.0 * cross_validate(X, y.T, n, gpr)
cv_scores = -1.0 * cross_validate(X, y.T, n, np.max(args.kfold) // n, gpr)
scores["rmse"] += cv_scores.tolist()
scores["repeat"] += [i] * len(cv_scores)
scores["n_folds"] += [n] * len(cv_scores)
scores["bval"] += [args.bval_shell] * len(cv_scores)
scores["snr"] += [snr_str] * len(cv_scores)

print(f"Finished {n}-fold cross-validation")
Expand All @@ -205,10 +184,6 @@ def main() -> None:
print(grouped[["rmse"]].mean())
print(grouped[["rmse"]].std())

cv = KFold(n_splits=3, shuffle=False, random_state=None)
predictions = cross_val_predict(gpr, X, y.T, cv=cv)
testsims.serialize_dwi(predictions.T, args.dwi_pred_data_fname)


if __name__ == "__main__":
main()
108 changes: 108 additions & 0 deletions scripts/dwi_gp_estimation_error_analysis_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright The NiPreps Developers <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#

"""
Plot the RMSE (mean and std dev) from the predicted DWI signal estimated using
Gaussian processes k-fold cross-validation.
"""

from __future__ import annotations

import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from eddymotion.viz.signals import plot_error


def _build_arg_parser() -> argparse.ArgumentParser:
    """
    Create the command-line argument parser of the script.

    Returns
    -------
    :obj:`~argparse.ArgumentParser`
        Parser taking two positional path arguments: the TSV file holding the
        error data and the SVG file where the error plot is saved.
    """
    # Positional arguments in command-line order; each one is parsed as a path
    positionals = (
        ("error_data_fname", "Filename of TSV file containing the error data to plot"),
        ("error_plot_fname", "Filename of SVG file where the error plot will be saved"),
    )

    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter, description=__doc__
    )
    for dest, help_text in positionals:
        arg_parser.add_argument(dest, type=Path, help=help_text)
    return arg_parser


def _parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
    """
    Run the given parser on the process command line.

    Parameters
    ----------
    parser : :obj:`~argparse.ArgumentParser`
        Argument parser for the script.

    Returns
    -------
    :obj:`~argparse.Namespace`
        Namespace holding the parsed argument values.
    """
    # No explicit argument list is passed, so argparse reads ``sys.argv[1:]``
    namespace = parser.parse_args()
    return namespace


def main() -> None:
    """
    Plot the Gaussian process estimation error analysis data.

    Read the TSV file given as the first command-line argument, compute the
    mean and standard deviation of the ``rmse`` column for each distinct value
    of the ``n_folds`` column, and save the resulting error plot to the file
    given as the second command-line argument.

    Raises
    ------
    :exc:`ValueError`
        If the ``snr`` or the ``bval`` column does not hold exactly one
        distinct value.
    """
    parser = _build_arg_parser()
    args = _parse_args(parser)

    df = pd.read_csv(args.error_data_fname, sep="\t", keep_default_na=False, na_values="n/a")

    # Plot the prediction error

    # Group once and summarize each fold count on its own: the number of RMSE
    # samples may differ across fold counts, so the groups cannot be assumed
    # to stack into a rectangular array
    rmse_by_kfold = {k: group.to_numpy() for k, group in df.groupby("n_folds")["rmse"]}
    kfolds = sorted(rmse_by_kfold)

    # The title describes a single experiment: ``item()`` fails unless each of
    # these columns holds exactly one distinct value
    snr = np.unique(df["snr"].values).item()
    bval = np.unique(df["bval"].values).item()

    # Population standard deviation (NumPy default, ``ddof=0``)
    mean = np.array([rmse_by_kfold[k].mean() for k in kfolds])
    std_dev = np.array([rmse_by_kfold[k].std() for k in kfolds])

    xlabel = "k"
    ylabel = "RMSE"
    title = f"Gaussian process estimation\n(b={bval} s/mm^2; SNR={snr})"
    fig = plot_error(kfolds, mean, std_dev, xlabel, ylabel, title)
    fig.savefig(args.error_plot_fname)
    # Release the figure explicitly; pyplot otherwise keeps it registered
    plt.close(fig)


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,20 @@
#

"""
Plot the RMSE (mean and std dev) and prediction surface from the predicted DWI
signal estimated using Gaussian processes k-fold cross-validation.
Plot the predicted DWI signal estimated using Gaussian processes.
"""

from __future__ import annotations

import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
from dipy.core.gradients import gradient_table
from dipy.io import read_bvals_bvecs

from eddymotion.viz.signals import plot_error, plot_prediction_surface
from eddymotion.viz.signals import plot_prediction_surface


def _build_arg_parser() -> argparse.ArgumentParser:
Expand All @@ -54,24 +51,19 @@ def _build_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
"error_data_fname",
help="Filename of TSV file containing the error data to plot",
type=Path,
)
parser.add_argument(
"dwi_gt_data_fname",
help="Filename of NIfTI file containing the ground truth DWI signal",
type=Path,
)
parser.add_argument(
"bval_data_fname",
help="Filename of b-val file containing the diffusion-encoding gradient b-vals",
help="Filename of b-val file containing the diffusion-weighting values",
type=Path,
)
parser.add_argument(
"bvec_data_fname",
help="Filename of b-vecs file containing the diffusion-encoding gradient b-vecs",
help="Filename of b-vecs file containing the diffusion-encoding gradient directions",
type=Path,
)
parser.add_argument(
Expand All @@ -80,8 +72,9 @@ def _build_arg_parser() -> argparse.ArgumentParser:
type=Path,
)
parser.add_argument(
"error_plot_fname",
help="Filename of SVG file where the error plot will be saved",
"bvec_pred_data_fname",
help="Filename of b-vecs file containing the diffusion-encoding gradient b-vecs where "
"the prediction is done",
type=Path,
)
parser.add_argument(
Expand Down Expand Up @@ -114,22 +107,6 @@ def main() -> None:
parser = _build_arg_parser()
args = _parse_args(parser)

df = pd.read_csv(args.error_data_fname, sep="\t", keep_default_na=False, na_values="n/a")

# Plot the prediction error
kfolds = sorted(np.unique(df["n_folds"].values))
snr = np.unique(df["snr"].values).item()
rmse_data = [df.groupby("n_folds").get_group(k)["rmse"].values for k in kfolds]
axis = 1
mean = np.mean(rmse_data, axis=axis)
std_dev = np.std(rmse_data, axis=axis)
xlabel = "k"
ylabel = "RMSE"
title = f"Gaussian process estimation\n(SNR={snr})"
fig = plot_error(kfolds, mean, std_dev, xlabel, ylabel, title)
fig.savefig(args.error_plot_fname)
plt.close(fig)

# Plot the predicted DWI signal at a single voxel

# Load the dMRI data
Expand All @@ -143,13 +120,15 @@ def main() -> None:
rng = np.random.default_rng(1234)
idx = rng.integers(0, signal.shape[0], size=1).item()

dirs = np.loadtxt(args.bvec_pred_data_fname)

title = "GP model signal prediction"
fig, _, _ = plot_prediction_surface(
signal[idx, ~gtab.b0s_mask],
y_pred[idx],
signal[idx, gtab.b0s_mask].item(),
gtab[~gtab.b0s_mask].bvecs,
gtab[~gtab.b0s_mask].bvecs,
dirs.T,
title,
"gray",
)
Expand Down
Loading

0 comments on commit 26ef287

Please sign in to comment.