Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

ENH: Add helper functions for DWI signal value visualization #243

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions src/eddymotion/viz/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import matplotlib.gridspec as gridspec
import numpy as np
from matplotlib import pyplot as plt
from scipy.spatial import ConvexHull, KDTree
from scipy.stats import pearsonr


Expand Down Expand Up @@ -112,3 +113,96 @@ def plot_correlation(x, y, title):
fig.tight_layout()

return fig, r


def calculate_sphere_pts(points, center):
"""Calculate the location of each point when it is expanded out to the sphere."""

kdtree = KDTree(points) # tree of nearest points
# d is an array of distances, i is an array of indices
d, i = kdtree.query(center, points.shape[0])
sphere_pts = np.zeros(points.shape, dtype=float)

radius = np.amax(d)
for p in range(points.shape[0]):
sphere_pts[p] = points[i[p]] * radius / d[p]
# points and the indices for where they were in the original lists
return sphere_pts, i


def compute_dmri_convex_hull(s, dirs, mask=None):
"""Compute the convex hull of the dMRI signal s."""

if mask is None:
mask = np.ones(len(dirs), dtype=bool)

# Scale the original sampling directions by the corresponding signal values
scaled_bvecs = dirs[mask] * np.asarray(s)[:, np.newaxis]

# Create the data for the convex hull: project the scaled vectors to a
# sphere
sphere_pts, sphere_idx = calculate_sphere_pts(scaled_bvecs, [0, 0, 0])

# Create the convex hull: find the right ordering of vertices for the
# triangles: ConvexHull finds the simplices of the points on the outside of
# the data set
hull = ConvexHull(sphere_pts)
triang_idx = hull.simplices # returns the list of indices for each triangle

return scaled_bvecs, sphere_idx, triang_idx


def plot_surface(scaled_vecs, sphere_idx, triang_idx, title, cmap):
"""Plot a surface."""

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

ax.scatter3D(
scaled_vecs[:, 0], scaled_vecs[:, 1], scaled_vecs[:, 2], s=2, c="black", alpha=1.0
)

surface = ax.plot_trisurf(
scaled_vecs[sphere_idx, 0],
scaled_vecs[sphere_idx, 1],
scaled_vecs[sphere_idx, 2],
triangles=triang_idx,
cmap=cmap,
alpha=0.6,
)

ax.view_init(10, 45)
ax.set_aspect("equal", adjustable="box")
ax.set_title(title)

return fig, ax, surface


def plot_signal_data(y, ax):
"""Plot the data provided as a scatter plot"""

ax.scatter(
y[:, 0], y[:, 1], y[:, 2], color="red", marker="*", alpha=0.8, s=5, label="Original points"
)


def plot_prediction_surface(y, y_pred, S0, y_dirs, y_pred_dirs, title, cmap):
"""Plot the prediction surface obtained by computing the convex hull of the
predicted signal data, and plot the true data as a scatter plot."""

# Scale the original sampling directions by the corresponding signal values
y_bvecs = y_dirs * np.asarray(y)[:, np.newaxis]

# Compute the convex hull
y_pred_bvecs, sphere_idx, triang_idx = compute_dmri_convex_hull(y_pred, y_pred_dirs)

# Plot the surface
fig, ax, surface = plot_surface(y_pred_bvecs, sphere_idx, triang_idx, title, cmap)

# Add the underlying signal to the plot
# plot_signal_data(y_bvecs/S0, ax)
plot_signal_data(y_bvecs, ax)

fig.tight_layout()

return fig, ax, surface