diff --git a/src/eddymotion/viz/signals.py b/src/eddymotion/viz/signals.py index 60c3933c..a2073137 100644 --- a/src/eddymotion/viz/signals.py +++ b/src/eddymotion/viz/signals.py @@ -25,6 +25,7 @@ import matplotlib.gridspec as gridspec import numpy as np from matplotlib import pyplot as plt +from scipy.spatial import ConvexHull, KDTree from scipy.stats import pearsonr @@ -112,3 +113,96 @@ def plot_correlation(x, y, title): fig.tight_layout() return fig, r + + +def calculate_sphere_pts(points, center): + """Calculate the location of each point when it is expanded out to the sphere.""" + + kdtree = KDTree(points) # tree of nearest points + # d is an array of distances, i is an array of indices + d, i = kdtree.query(center, points.shape[0]) + sphere_pts = np.zeros(points.shape, dtype=float) + + radius = np.amax(d) + for p in range(points.shape[0]): + sphere_pts[p] = points[i[p]] * radius / d[p] + # points and the indices for where they were in the original lists + return sphere_pts, i + + +def compute_dmri_convex_hull(s, dirs, mask=None): + """Compute the convex hull of the dMRI signal s.""" + + if mask is None: + mask = np.ones(len(dirs), dtype=bool) + + # Scale the original sampling directions by the corresponding signal values + scaled_bvecs = dirs[mask] * np.asarray(s)[:, np.newaxis] + + # Create the data for the convex hull: project the scaled vectors to a + # sphere + sphere_pts, sphere_idx = calculate_sphere_pts(scaled_bvecs, [0, 0, 0]) + + # Create the convex hull: find the right ordering of vertices for the + # triangles: ConvexHull finds the simplices of the points on the outside of + # the data set + hull = ConvexHull(sphere_pts) + triang_idx = hull.simplices # returns the list of indices for each triangle + + return scaled_bvecs, sphere_idx, triang_idx + + +def plot_surface(scaled_vecs, sphere_idx, triang_idx, title, cmap): + """Plot a surface.""" + + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + + ax.scatter3D( + scaled_vecs[:, 0], scaled_vecs[:, 1], scaled_vecs[:, 2], s=2, c="black", alpha=1.0 + ) + + surface = ax.plot_trisurf( + scaled_vecs[sphere_idx, 0], + scaled_vecs[sphere_idx, 1], + scaled_vecs[sphere_idx, 2], + triangles=triang_idx, + cmap=cmap, + alpha=0.6, + ) + + ax.view_init(10, 45) + ax.set_aspect("equal", adjustable="box") + ax.set_title(title) + + return fig, ax, surface + + +def plot_signal_data(y, ax): + """Plot the data provided as a scatter plot""" + + ax.scatter( + y[:, 0], y[:, 1], y[:, 2], color="red", marker="*", alpha=0.8, s=5, label="Original points" + ) + + +def plot_prediction_surface(y, y_pred, S0, y_dirs, y_pred_dirs, title, cmap): + """Plot the prediction surface obtained by computing the convex hull of the + predicted signal data, and plot the true data as a scatter plot.""" + + # Scale the original sampling directions by the corresponding signal values + y_bvecs = y_dirs * np.asarray(y)[:, np.newaxis] + + # Compute the convex hull + y_pred_bvecs, sphere_idx, triang_idx = compute_dmri_convex_hull(y_pred, y_pred_dirs) + + # Plot the surface + fig, ax, surface = plot_surface(y_pred_bvecs, sphere_idx, triang_idx, title, cmap) + + # Add the underlying signal to the plot + # plot_signal_data(y_bvecs/S0, ax) + plot_signal_data(y_bvecs, ax) + + fig.tight_layout() + + return fig, ax, surface