From b1fb2bbd6061e0cbd009b754c932cbc78fe025a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Tue, 4 Jun 2024 19:33:18 -0400 Subject: [PATCH] ENH: Add gaussian process DWI signal representation notebooks Add gaussian process DWI signal representation notebooks: - One of the notebooks uses a simulated DWI signal. - The second notebook uses a real DWI signal. --- docs/notebooks/dwi_gp.ipynb | 186 ++++++++++++++++++++ docs/notebooks/dwi_simulated_gp.ipynb | 242 ++++++++++++++++++++++++++ 2 files changed, 428 insertions(+) create mode 100644 docs/notebooks/dwi_gp.ipynb create mode 100644 docs/notebooks/dwi_simulated_gp.ipynb diff --git a/docs/notebooks/dwi_gp.ipynb b/docs/notebooks/dwi_gp.ipynb new file mode 100644 index 00000000..0f91ecd8 --- /dev/null +++ b/docs/notebooks/dwi_gp.ipynb @@ -0,0 +1,186 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": "Gaussian process notebook", + "id": "486923b289155658" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "import tempfile\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel\n", + "\n", + "from eddymotion import model\n", + "from eddymotion.data.dmri import DWI\n", + "from eddymotion.data.splitting import lovo_split\n", + "\n", + "datadir = Path(\"../../test\") # Adapt to your local path or download to a temp location using wget\n", + "\n", + "kernel = DotProduct() + WhiteKernel()\n", + "\n", + "dwi = DWI.from_filename(datadir / \"dwi.h5\")\n", + "\n", + "_dwi_data = dwi.dataobj\n", + "# Use a subset of the data for now to see that something is written to the\n", + "# output\n", + "# bvecs = dwi.gradients[:3, :].T\n", + "bvecs = dwi.gradients[:3, 10:13].T # b0 values have already been masked\n", + "# bvals = dwi.gradients[3:, 10:13].T # Only for inspection purposes: [[1005.], [1000.], [ 995.]]\n", + "dwi_data = _dwi_data[60:63, 60:64, 40:45, 10:13]\n", + "\n", + "# ToDo\n", + "# Provide proper values/estimates for these\n", + "a = 1\n", + "h = 1 # should be a NIfTI image\n", + "\n", + "num_iterations = 5\n", + "gp = model.GaussianProcessModel(\n", + " dwi=dwi, a=a, h=h, kernel=kernel, num_iterations=num_iterations\n", + ")\n", + "indices = list(range(bvecs.shape[0]))\n", + "# ToDo\n", + "# This should be done within the GP model class\n", + "# Apply lovo strategy properly\n", + "# Vectorize and parallelize\n", + "result_mean = np.zeros_like(dwi_data)\n", + "result_stddev = np.zeros_like(dwi_data)\n", + "for idx in indices:\n", + " lovo_idx = np.ones(len(indices), dtype=bool)\n", + " lovo_idx[idx] = False\n", + " X = bvecs[lovo_idx]\n", + " for i in range(dwi_data.shape[0]):\n", + " for j in range(dwi_data.shape[1]):\n", + " for k in range(dwi_data.shape[2]):\n", + " # ToDo\n", + " # Use a mask to avoid traversing background data\n", + " y = dwi_data[i, j, k, lovo_idx]\n", + " gp.fit(X, y)\n", + " pred_mean, pred_stddev = gp.predict(\n", + " bvecs[idx, :][np.newaxis]\n", + " ) # Can take multiple values X[:2, :]\n", + " result_mean[i, j, k, idx] = pred_mean.item()\n", + " result_stddev[i, j, k, idx] = pred_stddev.item()" + ], + "id": "da2274009534db61", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Plot the data", + "id": "77e77cd4c73409d3" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "from matplotlib import pyplot as plt \n", + "%matplotlib inline\n", + "\n", + "s = dwi_data[1, 1, 2, :]\n", + "s_hat_mean = result_mean[1, 1, 2, :]\n", + "s_hat_stddev = result_stddev[1, 1, 2, :]\n", + "x = np.asarray(indices)\n", + "\n", + "fig, ax = plt.subplots()\n", + "ax.plot(x, s_hat_mean, c=\"orange\", label=\"predicted\")\n", + "plt.fill_between(\n", + " x.ravel(),\n", + " s_hat_mean - 1.96 * s_hat_stddev,\n", + " s_hat_mean + 1.96 * s_hat_stddev,\n", + " alpha=0.5,\n", + " color=\"orange\",\n", + " label=r\"95% confidence interval\",\n", + ")\n", + "plt.scatter(x, s, c=\"b\", label=\"ground truth\")\n", + "ax.set_xlabel(\"bvec indices\")\n", + "ax.set_ylabel(\"signal\")\n", + "ax.legend()\n", + "plt.title(\"Gaussian process regression on dataset\")\n", + "\n", + "plt.show()" + ], + "id": "4e51f22890fb045a", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "Plot the DWI signal for a given voxel\n", + "Compute the DWI signal value wrt the b0 (how much larger/smaller is and add that delta to the unit sphere?) for each bvec direction and plot that?" + ], + "id": "694a4c075457425d" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "# from mpl_toolkits.mplot3d import Axes3D\n", + "# fig, ax = plt.subplots()\n", + "# ax = fig.add_subplot(111, projection='3d')\n", + "# plt.scatter(xx, yy, zz)" + ], + "id": "bb7d2aef53ac99f0", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Plot the DWI signal brain data\n", + "id": "62d7bc609b65c7cf" + }, + { + "metadata": {}, + "cell_type": "code", + "source": "# plot_dwi(dmri_dataset.dataobj, dmri_dataset.affine, gradient=data_test[1], black_bg=True)", + "id": "edb0e9d255516e38", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Plot the predicted DWI signal", + "id": "1a52e2450fc61dc6" + }, + { + "metadata": {}, + "cell_type": "code", + "source": "# plot_dwi(predicted, dmri_dataset.affine, gradient=data_test[1], black_bg=True);", + "id": "66150cf337b395e0", + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/notebooks/dwi_simulated_gp.ipynb b/docs/notebooks/dwi_simulated_gp.ipynb new file mode 100644 index 00000000..7ddfe8e8 --- /dev/null +++ b/docs/notebooks/dwi_simulated_gp.ipynb @@ -0,0 +1,242 @@ +{ + "cells": [ + { + "metadata": {}, + "cell_type": "markdown", + "source": "We define a method below to create a noise-free DWI signal using a multi-tensor model.", + "id": "6d3d512da282ff52" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "import numpy as np\n", + "\n", + "from dipy.core.gradients import gradient_table\n", + "from dipy.core.sphere import disperse_charges, HemiSphere, Sphere\n", + "from dipy.sims.voxel import multi_tensor\n", + "from sklearn.gaussian_process import GaussianProcessRegressor\n", + "\n", + "def create_multitensor_dmri_signal(hsph_dirs):\n", + " \"\"\"Create a multi-tensor, noise-free dMRI signal for simulation purposes. It\n", + " simulates two tensors crossing at 90 degrees with equal signal fraction, and\n", + " ``hsph_dirs`` diffusion-encoding gradients at b=1000 s/mm^2, plus a b0\n", + " volume.\"\"\"\n", + "\n", + " # Eigenvalues of tensors\n", + " eval1 = [0.0015, 0.0003, 0.0003]\n", + " eval2 = [0.0015, 0.0003, 0.0003]\n", + " mevals = np.array([eval1, eval2])\n", + "\n", + " # Polar coordinates (theta, phi) of the principal axis of each tensor\n", + " angles = [(0, 0), (90, 0)]\n", + "\n", + " # Percentage of the contribution of each tensor\n", + " fractions = [50, 50]\n", + "\n", + " # Create the gradient table placing random points on a hemisphere\n", + " rng = np.random.default_rng(1234)\n", + " theta = np.pi * rng.random(hsph_dirs)\n", + " phi = 2 * np.pi * rng.random(hsph_dirs)\n", + " hsph_initial = HemiSphere(theta=theta, phi=phi)\n", + "\n", + " # Move the points so that the electrostatic potential energy is minimized\n", + " iterations = 5000\n", + " hsph_updated, potential = disperse_charges(hsph_initial, iterations)\n", + " # Create a sphere\n", + " sph = Sphere(xyz=np.vstack((hsph_updated.vertices, -hsph_updated.vertices)))\n", + "\n", + " # Create the gradients\n", + " vertices = sph.vertices\n", + " values = np.ones(vertices.shape[0])\n", + " bvecs = vertices\n", + " bval_shell1 = 1000\n", + " bvals = bval_shell1 * values\n", + "\n", + " # Add a b0 value to the gradient table\n", + " bvecs = np.insert(bvecs, 0, np.array([0, 0, 0]), axis=0)\n", + " bvals = np.insert(bvals, 0, 0)\n", + " gtab = gradient_table(bvals, bvecs)\n", + "\n", + " # Create a noise-free signal\n", + " snr = None\n", + " S0 = 100\n", + " signal, sticks = multi_tensor(\n", + " gtab, mevals, S0=S0, angles=angles, fractions=fractions, snr=snr\n", + " )\n", + "\n", + " grad = np.vstack([gtab.bvecs.T, gtab.bvals])\n", + "\n", + " return signal, sticks, grad" + ], + "id": "a0f5bab019855954", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "We now create the DWI signal using 30 directions defined on the half sphere.", + "id": "7d5b5cbebaa82e19" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "hsph_dirs = 90\n", + "signal, sticks, grad = create_multitensor_dmri_signal(hsph_dirs)" + ], + "id": "e9545781fe5cf3b8", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Since there is only a single voxel in the simulated DWI signal, we add 3 axes before the diffusion-encoding gradient axis so that the plotting method can appropriately represent it. ", + "id": "a31eef208433f772" + }, + { + "metadata": {}, + "cell_type": "code", + "execution_count": null, + "source": [ + "voxel_idx = [0, 0, 0]\n", + "dwi_data = signal[np.newaxis, np.newaxis, np.newaxis, :]" + ], + "id": "c07f103d9bd347cc", + "outputs": [] + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "We now define the kernel that we will be using for the Gaussian process", + "id": "37e6f400ed44f2d9" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "from src.eddymotion.model.kernels import SphericalCovarianceKernel\n", + "\n", + "lambda_s = 2.0\n", + "a = 1.0\n", + "sigma_sq = 0.5\n", + "kernel = SphericalCovarianceKernel(lambda_s=lambda_s, a=a, sigma_sq=sigma_sq)" + ], + "id": "a66400cc9ee4c084", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "The ``grad`` gradient table instance is in RAS+b format, so we choose the diffusion-encoding gradient vectors (leaving out the first index, which corresponds to the b0 volume) to fit the Gaussian process.", + "id": "12b7491f09d7ea74" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "_grad = grad[:3, 1:]\n", + "gpr = GaussianProcessRegressor(kernel=kernel, random_state=0)\n", + "gpr.fit(_grad.T, signal[1:])" + ], + "id": "b64f511b41456359", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Now predict the signal on the last diffusion-encoding gradient vector. ", + "id": "351b0db081b0e890" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "_grad_pred = grad[:3, -1]\n", + "y_mean, y_std = gpr.predict(_grad_pred, return_std=True)" + ], + "id": "ecfc20d16f567c91", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Check whether the hyperparameters of the kernel have been optimized", + "id": "c025c7d8a8c47b4c" + }, + { + "metadata": {}, + "cell_type": "code", + "source": "gpr.kernel_", + "id": "1ee8a8d2584ea19", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "Plot the training data and the predictions from the Gaussian process", + "id": "8913c3c3ef7d50f7" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "from matplotlib import pyplot as plt \n", + "%matplotlib inline\n", + "\n", + "s = dwi_data[voxel_idx[0], voxel_idx[1], voxel_idx[2], :]\n", + "s_hat_mean = y_mean[voxel_idx[0], voxel_idx[1], voxel_idx[2], :]\n", + "s_hat_stddev = y_std[voxel_idx[0], voxel_idx[1], voxel_idx[2], :]\n", + "x = np.asarray(range(len(grad.bvals)))\n", + "\n", + "fig, ax = plt.subplots()\n", + "ax.plot(x, s_hat_mean, c=\"orange\", label=\"predicted\")\n", + "plt.fill_between(\n", + " x.ravel(),\n", + " s_hat_mean - 1.96 * s_hat_stddev,\n", + " s_hat_mean + 1.96 * s_hat_stddev,\n", + " alpha=0.5,\n", + " color=\"orange\",\n", + " label=r\"95% confidence interval\",\n", + ")\n", + "plt.scatter(x, s, c=\"b\", label=\"ground truth\")\n", + "ax.set_xlabel(\"bvec indices\")\n", + "ax.set_ylabel(\"signal\")\n", + "ax.legend()\n", + "plt.title(\"Gaussian process regression on dataset\")\n", + "\n", + "plt.show()" + ], + "id": "58f01aeb70aff1c1", + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}