Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
ENH: Add a script to plot the signal estimated by the GP
Browse files Browse the repository at this point in the history
Add a script to plot the signal estimated by the GP.
  • Loading branch information
jhlegarreta committed Sep 29, 2024
1 parent a7880a3 commit c102c22
Showing 1 changed file with 335 additions and 0 deletions.
335 changes: 335 additions & 0 deletions scripts/dwi_estimation_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,335 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2024 The NiPreps Developers <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#

""" "
Simulate the DWI signal from a single fiber and plot the predicted signal using a Gaussian process
estimator.
"""

import argparse

import numpy as np
from dipy.core.geometry import sphere2cart
from dipy.core.gradients import gradient_table
from dipy.core.sphere import HemiSphere, Sphere, disperse_charges
from dipy.sims.voxel import all_tensor_evecs, single_tensor
from matplotlib import pyplot as plt
from scipy.spatial import ConvexHull, KDTree

from eddymotion.model._dipy import GaussianProcessModel

SAMPLING_DIRECTIONS = 200


def add_b0(bvals, bvecs):
"""Add a b0 signal to the diffusion-encoding gradient values and vectors."""

_bvals = np.insert(bvals, 0, 0)
_bvecs = np.insert(bvecs, 0, np.array([0, 0, 0]), axis=0)

return _bvals, _bvecs


def create_single_fiber_evecs():
"""Create eigenvalues for a simulated single fiber."""

# Polar coordinates (theta, phi) of the principal axis of the tensor
angles = np.array([0, 0])
sticks = np.array(sphere2cart(1, np.deg2rad(angles[0]), np.deg2rad(angles[1])))
evecs = all_tensor_evecs(sticks)

return evecs


def create_random_polar_coordinates(hsph_dirs, seed=1234):
"""Create random polar coordinate values"""

rng = np.random.default_rng(seed)
theta = np.pi * rng.random(hsph_dirs)
phi = 2 * np.pi * rng.random(hsph_dirs)

return theta, phi


def create_diffusion_encoding_gradient_dirs(hsph_dirs, iterations=5000, seed=1234):
"""Create the dMRI gradient-encoding directions."""

# Create the gradient-encoding directions placing random points on a hemisphere
theta, phi = create_random_polar_coordinates(hsph_dirs, seed=seed)
hsph_initial = HemiSphere(theta=theta, phi=phi)

# Move the points so that the electrostatic potential energy is minimized
hsph_updated, potential = disperse_charges(hsph_initial, iterations)

# Create a sphere
return Sphere(xyz=np.vstack((hsph_updated.vertices, -hsph_updated.vertices)))


def create_single_shell_gradient_table(hsph_dirs, bval_shell, iterations=5000):
"""Create a single-shell gradient table."""

# Create diffusion-encoding gradient directions
sph = create_diffusion_encoding_gradient_dirs(hsph_dirs, iterations=iterations)

# Create the gradient bvals and bvecs
vertices = sph.vertices
values = np.ones(vertices.shape[0])
bvecs = vertices
bvals = bval_shell * values

# Add a b0 value to the gradient table
bvals, bvecs = add_b0(bvals, bvecs)
return gradient_table(bvals, bvecs)


def get_query_vectors(gtab, train_mask):
"""Get the diffusion-encoding gradient vectors where the signal is to be estimated from the
gradient table and the training mask: the vectors of interest are those that are masked in
the training mask. b0 values are excluded."""

idx = np.logical_and(~train_mask, ~gtab.b0s_mask)
return gtab.bvecs[idx], np.where(idx)[0]


def create_random_train_mask(gtab, size, seed=1234):
"""Create a mask for the gradient table where a ``size`` number of indices will be
excluded. b0 values are excluded."""

rng = np.random.default_rng(seed)

# Get the indices of the non-zero diffusion-encoding gradient vector indices
nnzero_degv_idx = np.where(~gtab.b0s_mask)[0]

if nnzero_degv_idx.size < size:
raise ValueError(
f"Requested {size} values for masking; gradient table has {nnzero_degv_idx.size} "
"non-zero diffusion-encoding gradient vectors. Reduce the number of masked values."
)

lo = rng.choice(nnzero_degv_idx, size=size, replace=False)

# Exclude the b0s
zero_degv_idx = np.asarray(list(set(range(len(gtab.bvals))).difference(nnzero_degv_idx)))
lo = np.hstack([zero_degv_idx, lo])

train_mask = np.ones(len(gtab.bvals), dtype=bool)
train_mask[lo] = False

return train_mask


def perform_experiment(gtab, S0, evals1, evecs, snr):
"""Perform experiment: estimate the dMRI signal on a set of directions fitting a
Gaussian process to the rest of the data."""

# Fix the random number generator for reproducibility when generating the
# signal
seed = 1234
rng = np.random.default_rng(seed)

# Define the Gaussian process model parameters
kernel_model = "spherical"
lambda_s = 2.0
a = 1.0
sigma_sq = 0.5

# Define the Gaussian process model instance
gp_model = GaussianProcessModel(
kernel_model=kernel_model, lambda_s=lambda_s, a=a, sigma_sq=sigma_sq
)

# Create the DWI signal using a single tensor
signal = single_tensor(gtab, S0=S0, evals=evals1, evecs=evecs, snr=snr, rng=rng)

# Use all available data for training
gpfit = gp_model.fit(signal[~gtab.b0s_mask], gtab[~gtab.b0s_mask])

# Predict on an oversampled set of random directions over the unit sphere
# theta, phi = create_random_polar_coordinates(SAMPLING_DIRECTIONS, seed=seed)
# sph = Sphere(theta=theta, phi=phi)

# ToDo
# Not sure why all predictions are zero in gpfit.predict(sph.vertices)
# Also, when creating the convex hull, the gtab required is the one that
# would correspond to the new directions, so a new gtab would need to be
# generated
# return signal, gpfit.predict(sph.vertices), sph.vertices
# For now, predict on the same data
return signal, gpfit.predict(gtab[~gtab.b0s_mask].bvecs), gtab[~gtab.b0s_mask].bvecs


def calculate_sphere_pts(points, center):
"""Calculate the location of each point when it is expanded out to the sphere."""

kdtree = KDTree(points) # tree of nearest points
# d is an array of distances, i is an array of indices
d, i = kdtree.query(center, points.shape[0])
sphere_pts = np.zeros(points.shape, dtype=float)

radius = np.amax(d)
for p in range(points.shape[0]):
sphere_pts[p] = points[i[p]] * radius / d[p]
# points and the indices for where they were in the original lists
return sphere_pts, i


def compute_dmri_convex_hull(s, dirs, mask=None):
"""Compute the convex hull of the dMRI signal s."""

if mask is None:
mask = np.ones(len(dirs), dtype=bool)

# Scale the original sampling directions by the corresponding signal values
scaled_bvecs = dirs[mask] * np.asarray(s)[:, np.newaxis]

# Create the data for the convex hull: project the scaled vectors to a
# sphere
sphere_pts, sphere_idx = calculate_sphere_pts(scaled_bvecs, [0, 0, 0])

# Create the convex hull: find the right ordering of vertices for the
# triangles: ConvexHull finds the simplices of the points on the outside of
# the data set
hull = ConvexHull(sphere_pts)
triang_idx = hull.simplices # returns the list of indices for each triangle

return scaled_bvecs, sphere_idx, triang_idx


def plot_surface(scaled_vecs, sphere_idx, triang_idx, title, cmap):
"""Plot a surface."""

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

ax.scatter3D(
scaled_vecs[:, 0], scaled_vecs[:, 1], scaled_vecs[:, 2], s=2, c="black", alpha=1.0
)

surface = ax.plot_trisurf(
scaled_vecs[sphere_idx, 0],
scaled_vecs[sphere_idx, 1],
scaled_vecs[sphere_idx, 2],
triangles=triang_idx,
cmap=cmap,
alpha=0.6,
)

ax.view_init(10, 45)
ax.set_aspect("equal", adjustable="box")
ax.set_title(title)

return fig, ax, surface


def plot_signal_data(y, ax):
"""Plot the data provided as a scatter plot"""

ax.scatter(
y[:, 0], y[:, 1], y[:, 2], color="red", marker="*", alpha=0.8, s=5, label="Original points"
)


def plot_prediction_surface(y, y_pred, S0, y_dirs, y_pred_dirs, title, cmap):
"""Plot the prediction surface obtained by computing the convex hull of the
predicted signal data, and plot the true data as a scatter plot."""

# Scale the original sampling directions by the corresponding signal values
y_bvecs = y_dirs * np.asarray(y)[:, np.newaxis]

# Compute the convex hull
y_pred_bvecs, sphere_idx, triang_idx = compute_dmri_convex_hull(y_pred, y_pred_dirs)

# Plot the surface
fig, ax, surface = plot_surface(y_pred_bvecs, sphere_idx, triang_idx, title, cmap)

# Add the underlying signal to the plot
# plot_signal_data(y_bvecs/S0, ax)
plot_signal_data(y_bvecs, ax)

fig.tight_layout()

return fig, ax, surface


def _build_arg_parser():
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
"hsph_dirs",
help="Number of diffusion gradient-encoding directions in the half sphere",
type=int,
)
parser.add_argument(
"bval_shell",
help="Shell b-value",
type=float,
)
parser.add_argument(
"S0",
help="S0 value",
type=float,
)
parser.add_argument(
"--evals1",
help="Eigenvalues of the tensor",
nargs="+",
type=float,
)
parser.add_argument(
"--snr",
help="Signal to noise ratio",
type=float,
)
return parser


def _parse_args(parser):
args = parser.parse_args()

return args


def main():
parser = _build_arg_parser()
args = _parse_args(parser)

# create eigenvectors for a single fiber
evecs = create_single_fiber_evecs()

# Create a gradient table for a single-shell
gtab = create_single_shell_gradient_table(args.hsph_dirs, args.bval_shell)

# Estimate the dMRI signal using a Gaussian process estimator
y, y_pred, y_pred_dirs = perform_experiment(gtab, args.S0, args.evals1, evecs, args.snr)

# Plot the predicted signal
title = "GP model signal prediction\n(single-fiber)"
fig, _, _ = plot_prediction_surface(
y[~gtab.b0s_mask], y_pred, args.S0, gtab.bvecs[~gtab.b0s_mask], y_pred_dirs, title, "gray"
)
fig.savefig(args.gp_pred_plot_fname, format="svg")


if __name__ == "__main__":
main()

0 comments on commit c102c22

Please sign in to comment.