Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
ENH: Add gaussian process DWI signal representation notebooks
Browse files Browse the repository at this point in the history
Add gaussian process DWI signal representation notebooks:
- One of the notebooks uses a simulated DWI signal.
- The second notebook uses a real DWI signal.
  • Loading branch information
jhlegarreta committed Jul 1, 2024
1 parent 8c0bf36 commit f363635
Show file tree
Hide file tree
Showing 2 changed files with 473 additions and 0 deletions.
186 changes: 186 additions & 0 deletions docs/notebooks/dwi_gp.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "Gaussian process notebook",
"id": "486923b289155658"
},
{
"metadata": {},
"cell_type": "code",
"source": [
"import tempfile\n",
"from pathlib import Path\n",
"\n",
"import numpy as np\n",
"from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel\n",
"\n",
"from eddymotion import model\n",
"from eddymotion.data.dmri import DWI\n",
"from eddymotion.data.splitting import lovo_split\n",
"\n",
"datadir = Path(\"../../test\") # Adapt to your local path or download to a temp location using wget\n",
"\n",
"kernel = DotProduct() + WhiteKernel()\n",
"\n",
"dwi = DWI.from_filename(datadir / \"dwi.h5\")\n",
"\n",
"_dwi_data = dwi.dataobj\n",
"# Use a subset of the data for now to see that something is written to the\n",
"# output\n",
"# bvecs = dwi.gradients[:3, :].T\n",
"bvecs = dwi.gradients[:3, 10:13].T # b0 values have already been masked\n",
"# bvals = dwi.gradients[3:, 10:13].T # Only for inspection purposes: [[1005.], [1000.], [ 995.]]\n",
"dwi_data = _dwi_data[60:63, 60:64, 40:45, 10:13]\n",
"\n",
"# ToDo\n",
"# Provide proper values/estimates for these\n",
"a = 1\n",
"h = 1 # should be a NIfTI image\n",
"\n",
"num_iterations = 5\n",
"gp = model.GaussianProcessModel(\n",
" dwi=dwi, a=a, h=h, kernel=kernel, num_iterations=num_iterations\n",
")\n",
"indices = list(range(bvecs.shape[0]))\n",
"# ToDo\n",
"# This should be done within the GP model class\n",
"# Apply lovo strategy properly\n",
"# Vectorize and parallelize\n",
"result_mean = np.zeros_like(dwi_data)\n",
"result_stddev = np.zeros_like(dwi_data)\n",
"for idx in indices:\n",
" lovo_idx = np.ones(len(indices), dtype=bool)\n",
" lovo_idx[idx] = False\n",
" X = bvecs[lovo_idx]\n",
" for i in range(dwi_data.shape[0]):\n",
" for j in range(dwi_data.shape[1]):\n",
" for k in range(dwi_data.shape[2]):\n",
" # ToDo\n",
" # Use a mask to avoid traversing background data\n",
" y = dwi_data[i, j, k, lovo_idx]\n",
" gp.fit(X, y)\n",
" pred_mean, pred_stddev = gp.predict(\n",
" bvecs[idx, :][np.newaxis]\n",
" ) # Can take multiple values X[:2, :]\n",
" result_mean[i, j, k, idx] = pred_mean.item()\n",
" result_stddev[i, j, k, idx] = pred_stddev.item()"
],
"id": "da2274009534db61",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Plot the data",
"id": "77e77cd4c73409d3"
},
{
"metadata": {},
"cell_type": "code",
"source": [
"from matplotlib import pyplot as plt \n",
"%matplotlib inline\n",
"\n",
"s = dwi_data[1, 1, 2, :]\n",
"s_hat_mean = result_mean[1, 1, 2, :]\n",
"s_hat_stddev = result_stddev[1, 1, 2, :]\n",
"x = np.asarray(indices)\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.plot(x, s_hat_mean, c=\"orange\", label=\"predicted\")\n",
"plt.fill_between(\n",
" x.ravel(),\n",
" s_hat_mean - 1.96 * s_hat_stddev,\n",
" s_hat_mean + 1.96 * s_hat_stddev,\n",
" alpha=0.5,\n",
" color=\"orange\",\n",
" label=r\"95% confidence interval\",\n",
")\n",
"plt.scatter(x, s, c=\"b\", label=\"ground truth\")\n",
"ax.set_xlabel(\"bvec indices\")\n",
"ax.set_ylabel(\"signal\")\n",
"ax.legend()\n",
"plt.title(\"Gaussian process regression on dataset\")\n",
"\n",
"plt.show()"
],
"id": "4e51f22890fb045a",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": [
"Plot the DWI signal for a given voxel\n",
"Compute the DWI signal value wrt the b0 (how much larger/smaller is and add that delta to the unit sphere?) for each bvec direction and plot that?"
],
"id": "694a4c075457425d"
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# from mpl_toolkits.mplot3d import Axes3D\n",
"# fig, ax = plt.subplots()\n",
"# ax = fig.add_subplot(111, projection='3d')\n",
"# plt.scatter(xx, yy, zz)"
],
"id": "bb7d2aef53ac99f0",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Plot the DWI signal brain data\n",
"id": "62d7bc609b65c7cf"
},
{
"metadata": {},
"cell_type": "code",
"source": "# plot_dwi(dmri_dataset.dataobj, dmri_dataset.affine, gradient=data_test[1], black_bg=True)",
"id": "edb0e9d255516e38",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Plot the predicted DWI signal",
"id": "1a52e2450fc61dc6"
},
{
"metadata": {},
"cell_type": "code",
"source": "# plot_dwi(predicted, dmri_dataset.affine, gradient=data_test[1], black_bg=True);",
"id": "66150cf337b395e0",
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading

0 comments on commit f363635

Please sign in to comment.