This repository has been archived by the owner on Dec 20, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ENH: Add a script to plot the signal estimated by the GP
Add a script to plot the signal estimated by the GP.
- Loading branch information
1 parent
a7880a3
commit c102c22
Showing
1 changed file
with
335 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,335 @@ | ||
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- | ||
# vi: set ft=python sts=4 ts=4 sw=4 et: | ||
# | ||
# Copyright 2024 The NiPreps Developers <[email protected]> | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# We support and encourage derived works from this project, please read | ||
# about our expectations at | ||
# | ||
# https://www.nipreps.org/community/licensing/ | ||
# | ||
|
||
""" " | ||
Simulate the DWI signal from a single fiber and plot the predicted signal using a Gaussian process | ||
estimator. | ||
""" | ||
|
||
import argparse | ||
|
||
import numpy as np | ||
from dipy.core.geometry import sphere2cart | ||
from dipy.core.gradients import gradient_table | ||
from dipy.core.sphere import HemiSphere, Sphere, disperse_charges | ||
from dipy.sims.voxel import all_tensor_evecs, single_tensor | ||
from matplotlib import pyplot as plt | ||
from scipy.spatial import ConvexHull, KDTree | ||
|
||
from eddymotion.model._dipy import GaussianProcessModel | ||
|
||
SAMPLING_DIRECTIONS = 200 | ||
|
||
|
||
def add_b0(bvals, bvecs): | ||
"""Add a b0 signal to the diffusion-encoding gradient values and vectors.""" | ||
|
||
_bvals = np.insert(bvals, 0, 0) | ||
_bvecs = np.insert(bvecs, 0, np.array([0, 0, 0]), axis=0) | ||
|
||
return _bvals, _bvecs | ||
|
||
|
||
def create_single_fiber_evecs(): | ||
"""Create eigenvalues for a simulated single fiber.""" | ||
|
||
# Polar coordinates (theta, phi) of the principal axis of the tensor | ||
angles = np.array([0, 0]) | ||
sticks = np.array(sphere2cart(1, np.deg2rad(angles[0]), np.deg2rad(angles[1]))) | ||
evecs = all_tensor_evecs(sticks) | ||
|
||
return evecs | ||
|
||
|
||
def create_random_polar_coordinates(hsph_dirs, seed=1234): | ||
"""Create random polar coordinate values""" | ||
|
||
rng = np.random.default_rng(seed) | ||
theta = np.pi * rng.random(hsph_dirs) | ||
phi = 2 * np.pi * rng.random(hsph_dirs) | ||
|
||
return theta, phi | ||
|
||
|
||
def create_diffusion_encoding_gradient_dirs(hsph_dirs, iterations=5000, seed=1234): | ||
"""Create the dMRI gradient-encoding directions.""" | ||
|
||
# Create the gradient-encoding directions placing random points on a hemisphere | ||
theta, phi = create_random_polar_coordinates(hsph_dirs, seed=seed) | ||
hsph_initial = HemiSphere(theta=theta, phi=phi) | ||
|
||
# Move the points so that the electrostatic potential energy is minimized | ||
hsph_updated, potential = disperse_charges(hsph_initial, iterations) | ||
|
||
# Create a sphere | ||
return Sphere(xyz=np.vstack((hsph_updated.vertices, -hsph_updated.vertices))) | ||
|
||
|
||
def create_single_shell_gradient_table(hsph_dirs, bval_shell, iterations=5000): | ||
"""Create a single-shell gradient table.""" | ||
|
||
# Create diffusion-encoding gradient directions | ||
sph = create_diffusion_encoding_gradient_dirs(hsph_dirs, iterations=iterations) | ||
|
||
# Create the gradient bvals and bvecs | ||
vertices = sph.vertices | ||
values = np.ones(vertices.shape[0]) | ||
bvecs = vertices | ||
bvals = bval_shell * values | ||
|
||
# Add a b0 value to the gradient table | ||
bvals, bvecs = add_b0(bvals, bvecs) | ||
return gradient_table(bvals, bvecs) | ||
|
||
|
||
def get_query_vectors(gtab, train_mask): | ||
"""Get the diffusion-encoding gradient vectors where the signal is to be estimated from the | ||
gradient table and the training mask: the vectors of interest are those that are masked in | ||
the training mask. b0 values are excluded.""" | ||
|
||
idx = np.logical_and(~train_mask, ~gtab.b0s_mask) | ||
return gtab.bvecs[idx], np.where(idx)[0] | ||
|
||
|
||
def create_random_train_mask(gtab, size, seed=1234): | ||
"""Create a mask for the gradient table where a ``size`` number of indices will be | ||
excluded. b0 values are excluded.""" | ||
|
||
rng = np.random.default_rng(seed) | ||
|
||
# Get the indices of the non-zero diffusion-encoding gradient vector indices | ||
nnzero_degv_idx = np.where(~gtab.b0s_mask)[0] | ||
|
||
if nnzero_degv_idx.size < size: | ||
raise ValueError( | ||
f"Requested {size} values for masking; gradient table has {nnzero_degv_idx.size} " | ||
"non-zero diffusion-encoding gradient vectors. Reduce the number of masked values." | ||
) | ||
|
||
lo = rng.choice(nnzero_degv_idx, size=size, replace=False) | ||
|
||
# Exclude the b0s | ||
zero_degv_idx = np.asarray(list(set(range(len(gtab.bvals))).difference(nnzero_degv_idx))) | ||
lo = np.hstack([zero_degv_idx, lo]) | ||
|
||
train_mask = np.ones(len(gtab.bvals), dtype=bool) | ||
train_mask[lo] = False | ||
|
||
return train_mask | ||
|
||
|
||
def perform_experiment(gtab, S0, evals1, evecs, snr): | ||
"""Perform experiment: estimate the dMRI signal on a set of directions fitting a | ||
Gaussian process to the rest of the data.""" | ||
|
||
# Fix the random number generator for reproducibility when generating the | ||
# signal | ||
seed = 1234 | ||
rng = np.random.default_rng(seed) | ||
|
||
# Define the Gaussian process model parameters | ||
kernel_model = "spherical" | ||
lambda_s = 2.0 | ||
a = 1.0 | ||
sigma_sq = 0.5 | ||
|
||
# Define the Gaussian process model instance | ||
gp_model = GaussianProcessModel( | ||
kernel_model=kernel_model, lambda_s=lambda_s, a=a, sigma_sq=sigma_sq | ||
) | ||
|
||
# Create the DWI signal using a single tensor | ||
signal = single_tensor(gtab, S0=S0, evals=evals1, evecs=evecs, snr=snr, rng=rng) | ||
|
||
# Use all available data for training | ||
gpfit = gp_model.fit(signal[~gtab.b0s_mask], gtab[~gtab.b0s_mask]) | ||
|
||
# Predict on an oversampled set of random directions over the unit sphere | ||
# theta, phi = create_random_polar_coordinates(SAMPLING_DIRECTIONS, seed=seed) | ||
# sph = Sphere(theta=theta, phi=phi) | ||
|
||
# ToDo | ||
# Not sure why all predictions are zero in gpfit.predict(sph.vertices) | ||
# Also, when creating the convex hull, the gtab required is the one that | ||
# would correspond to the new directions, so a new gtab would need to be | ||
# generated | ||
# return signal, gpfit.predict(sph.vertices), sph.vertices | ||
# For now, predict on the same data | ||
return signal, gpfit.predict(gtab[~gtab.b0s_mask].bvecs), gtab[~gtab.b0s_mask].bvecs | ||
|
||
|
||
def calculate_sphere_pts(points, center): | ||
"""Calculate the location of each point when it is expanded out to the sphere.""" | ||
|
||
kdtree = KDTree(points) # tree of nearest points | ||
# d is an array of distances, i is an array of indices | ||
d, i = kdtree.query(center, points.shape[0]) | ||
sphere_pts = np.zeros(points.shape, dtype=float) | ||
|
||
radius = np.amax(d) | ||
for p in range(points.shape[0]): | ||
sphere_pts[p] = points[i[p]] * radius / d[p] | ||
# points and the indices for where they were in the original lists | ||
return sphere_pts, i | ||
|
||
|
||
def compute_dmri_convex_hull(s, dirs, mask=None): | ||
"""Compute the convex hull of the dMRI signal s.""" | ||
|
||
if mask is None: | ||
mask = np.ones(len(dirs), dtype=bool) | ||
|
||
# Scale the original sampling directions by the corresponding signal values | ||
scaled_bvecs = dirs[mask] * np.asarray(s)[:, np.newaxis] | ||
|
||
# Create the data for the convex hull: project the scaled vectors to a | ||
# sphere | ||
sphere_pts, sphere_idx = calculate_sphere_pts(scaled_bvecs, [0, 0, 0]) | ||
|
||
# Create the convex hull: find the right ordering of vertices for the | ||
# triangles: ConvexHull finds the simplices of the points on the outside of | ||
# the data set | ||
hull = ConvexHull(sphere_pts) | ||
triang_idx = hull.simplices # returns the list of indices for each triangle | ||
|
||
return scaled_bvecs, sphere_idx, triang_idx | ||
|
||
|
||
def plot_surface(scaled_vecs, sphere_idx, triang_idx, title, cmap): | ||
"""Plot a surface.""" | ||
|
||
fig = plt.figure() | ||
ax = fig.add_subplot(111, projection="3d") | ||
|
||
ax.scatter3D( | ||
scaled_vecs[:, 0], scaled_vecs[:, 1], scaled_vecs[:, 2], s=2, c="black", alpha=1.0 | ||
) | ||
|
||
surface = ax.plot_trisurf( | ||
scaled_vecs[sphere_idx, 0], | ||
scaled_vecs[sphere_idx, 1], | ||
scaled_vecs[sphere_idx, 2], | ||
triangles=triang_idx, | ||
cmap=cmap, | ||
alpha=0.6, | ||
) | ||
|
||
ax.view_init(10, 45) | ||
ax.set_aspect("equal", adjustable="box") | ||
ax.set_title(title) | ||
|
||
return fig, ax, surface | ||
|
||
|
||
def plot_signal_data(y, ax): | ||
"""Plot the data provided as a scatter plot""" | ||
|
||
ax.scatter( | ||
y[:, 0], y[:, 1], y[:, 2], color="red", marker="*", alpha=0.8, s=5, label="Original points" | ||
) | ||
|
||
|
||
def plot_prediction_surface(y, y_pred, S0, y_dirs, y_pred_dirs, title, cmap): | ||
"""Plot the prediction surface obtained by computing the convex hull of the | ||
predicted signal data, and plot the true data as a scatter plot.""" | ||
|
||
# Scale the original sampling directions by the corresponding signal values | ||
y_bvecs = y_dirs * np.asarray(y)[:, np.newaxis] | ||
|
||
# Compute the convex hull | ||
y_pred_bvecs, sphere_idx, triang_idx = compute_dmri_convex_hull(y_pred, y_pred_dirs) | ||
|
||
# Plot the surface | ||
fig, ax, surface = plot_surface(y_pred_bvecs, sphere_idx, triang_idx, title, cmap) | ||
|
||
# Add the underlying signal to the plot | ||
# plot_signal_data(y_bvecs/S0, ax) | ||
plot_signal_data(y_bvecs, ax) | ||
|
||
fig.tight_layout() | ||
|
||
return fig, ax, surface | ||
|
||
|
||
def _build_arg_parser(): | ||
parser = argparse.ArgumentParser( | ||
description=__doc__, formatter_class=argparse.RawTextHelpFormatter | ||
) | ||
parser.add_argument( | ||
"hsph_dirs", | ||
help="Number of diffusion gradient-encoding directions in the half sphere", | ||
type=int, | ||
) | ||
parser.add_argument( | ||
"bval_shell", | ||
help="Shell b-value", | ||
type=float, | ||
) | ||
parser.add_argument( | ||
"S0", | ||
help="S0 value", | ||
type=float, | ||
) | ||
parser.add_argument( | ||
"--evals1", | ||
help="Eigenvalues of the tensor", | ||
nargs="+", | ||
type=float, | ||
) | ||
parser.add_argument( | ||
"--snr", | ||
help="Signal to noise ratio", | ||
type=float, | ||
) | ||
return parser | ||
|
||
|
||
def _parse_args(parser): | ||
args = parser.parse_args() | ||
|
||
return args | ||
|
||
|
||
def main(): | ||
parser = _build_arg_parser() | ||
args = _parse_args(parser) | ||
|
||
# create eigenvectors for a single fiber | ||
evecs = create_single_fiber_evecs() | ||
|
||
# Create a gradient table for a single-shell | ||
gtab = create_single_shell_gradient_table(args.hsph_dirs, args.bval_shell) | ||
|
||
# Estimate the dMRI signal using a Gaussian process estimator | ||
y, y_pred, y_pred_dirs = perform_experiment(gtab, args.S0, args.evals1, evecs, args.snr) | ||
|
||
# Plot the predicted signal | ||
title = "GP model signal prediction\n(single-fiber)" | ||
fig, _, _ = plot_prediction_surface( | ||
y[~gtab.b0s_mask], y_pred, args.S0, gtab.bvecs[~gtab.b0s_mask], y_pred_dirs, title, "gray" | ||
) | ||
fig.savefig(args.gp_pred_plot_fname, format="svg") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |