From f549baef57fad2e6bd3e19e15677d2ea06de7355 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Mon, 4 Nov 2024 19:25:50 -0500 Subject: [PATCH] ENH: Add GP estimation notebook Add GP estimation notebook. --- docs/notebooks/dwi_gp_estimation.ipynb | 237 +++++++++++++++++++++++++ 1 file changed, 237 insertions(+) create mode 100644 docs/notebooks/dwi_gp_estimation.ipynb diff --git a/docs/notebooks/dwi_gp_estimation.ipynb b/docs/notebooks/dwi_gp_estimation.ipynb new file mode 100644 index 00000000..56c76b67 --- /dev/null +++ b/docs/notebooks/dwi_gp_estimation.ipynb @@ -0,0 +1,237 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d11e5969ed6af8a5", + "metadata": {}, + "source": [ + "Estimate a DWI signal using the eddymotion Gaussian Process (GP) regressor estimator." + ] + }, + { + "cell_type": "markdown", + "id": "3476a8e9cfefd4b8", + "metadata": {}, + "source": [ + "Download the \"Sherbrooke 3-shell\" dataset using DIPY and select the b=1000 s/mm^2 shell data." + ] + }, + { + "cell_type": "code", + "id": "69a3bc6b4fbe7036", + "metadata": { + "jupyter": { + "is_executing": true + }, + "ExecuteTime": { + "start_time": "2024-11-05T12:46:54.497856Z" + } + }, + "source": [ + "import dipy.data as dpd\n", + "import nibabel as nib\n", + "import numpy as np\n", + "from dipy.core.gradients import get_bval_indices\n", + "from dipy.io import read_bvals_bvecs\n", + "from dipy.segment.mask import median_otsu\n", + "\n", + "seed = 1234\n", + "rng = np.random.default_rng(seed)\n", + "\n", + "name = \"sherbrooke_3shell\"\n", + "\n", + "dwi_fname, bval_fname, bvec_fname = dpd.get_fnames(name=name)\n", + "dwi_data = nib.load(dwi_fname).get_fdata()\n", + "bvals, bvecs = read_bvals_bvecs(bval_fname, bvec_fname)\n", + "\n", + "_, brain_mask = median_otsu(dwi_data, vol_idx=[0])\n", + "\n", + "bval = 1000\n", + "indices = get_bval_indices(bvals, bval, tol=20)\n", + "\n", + "bvecs_shell = bvecs[indices]\n", + "shell_data = dwi_data[..., indices]" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "9bd417117afaad49", + "metadata": {}, + "source": [ + "Visualize a slice of the data for a given DWI volume." + ] + }, + { + "cell_type": "code", + "id": "d8547475686958f3", + "metadata": {}, + "source": [ + "# Plot a slice\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "import numpy as np\n", + "\n", + "dwi_vol_idx = len(indices) // 2\n", + "slice_idx = list(map(int, np.divide(dwi_data.shape[:-1], 2)))\n", + "\n", + "x_slice = dwi_data[slice_idx[0], :, :, dwi_vol_idx]\n", + "y_slice = dwi_data[:, slice_idx[1], :, dwi_vol_idx]\n", + "z_slice = dwi_data[:, :, slice_idx[2], dwi_vol_idx]\n", + "slices = [x_slice, y_slice, z_slice]\n", + "\n", + "fig, axes = plt.subplots(1, len(slices))\n", + "for i, _slice in enumerate(slices):\n", + " axes[i].imshow(_slice.T, cmap=\"gray\", origin=\"lower\", aspect='equal')\n", + " axes[i].set_axis_off()\n", + "\n", + "plt.show()" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "9dcab811fe667617", + "metadata": {}, + "source": [ + "Define the EddyMotionGPR instance." + ] + }, + { + "cell_type": "code", + "id": "7d5d9562339bc849", + "metadata": {}, + "source": [ + "from eddymotion.model.gpr import EddyMotionGPR, SphericalKriging\n", + "\n", + "beta_a = 1.38\n", + "beta_l = 1 / 2.1\n", + "kernel = SphericalKriging(beta_a=beta_a, beta_l=beta_l)\n", + "\n", + "alpha = 0.1\n", + "disp = True\n", + "optimizer = None\n", + "gpr = EddyMotionGPR(kernel=kernel, alpha=alpha, disp=disp, optimizer=optimizer)\n" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "ea5cc8036fa0ab48", + "metadata": {}, + "source": [ + "Do not optimize the parameters in the fitting. " + ] + }, + { + "cell_type": "code", + "id": "7e93b99c3b072d99", + "metadata": {}, + "source": [ + "X_train = bvecs_shell\n", + "# Consider only brain voxels\n", + "dwi_mask = np.repeat(brain_mask[..., np.newaxis], shell_data.shape[-1], axis=-1)\n", + "y = shell_data[dwi_mask].reshape((X_train.shape[0], -1))\n", + "gpr.fit(X_train, y)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "dfdd82afbdb22790", + "metadata": {}, + "source": [ + "Predict on a randomly chosen direction." + ] + }, + { + "cell_type": "code", + "id": "ae3407b31b14928d", + "metadata": {}, + "source": [ + "# Pick a direction to predict\n", + "idx = rng.integers(0, len(indices))\n", + "X_test = bvecs_shell[idx][np.newaxis, :]\n", + "y_pred = gpr.predict(X_test)\n", + "\n", + "rmse = np.sqrt(np.mean(np.square(y[idx, ...] - y_pred.squeeze())))\n", + "_rmse_element = np.sqrt(np.square(y[idx, ...] - y_pred.squeeze()))\n", + "\n", + "print(f\"RMSE: {rmse}\")\n", + "threshold = 10\n", + "n_error_thr = len(_rmse_element[_rmse_element > threshold])\n", + "ratio = n_error_thr / len(_rmse_element) * 100\n", + "print(f\"Number of RMSE values above {threshold}: {n_error_thr} ({ratio:.2f}%)\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "id": "74b040c05621f2d9", + "metadata": {}, + "source": [ + "Visualize the prediction." + ] + }, + { + "cell_type": "code", + "id": "a130de2a03dff2b5", + "metadata": {}, + "source": [ + "# Reconstruct the data array\n", + "brain_mask_idx = np.where(brain_mask)\n", + "_y = np.zeros((shell_data.shape[:-1]), dtype=y.dtype)\n", + "_y[brain_mask_idx] = y_pred.squeeze()\n", + "\n", + "x_slice = _y[slice_idx[0], :, :]\n", + "y_slice = _y[:, slice_idx[1], :]\n", + "z_slice = _y[:, :, slice_idx[2]]\n", + "slices = [x_slice, y_slice, z_slice]\n", + "\n", + "fig, axes = plt.subplots(1, len(slices))\n", + "for i, _slice in enumerate(slices):\n", + " axes[i].imshow(_slice.T, cmap=\"gray\", origin=\"lower\", aspect='equal')\n", + " axes[i].set_axis_off()\n", + "\n", + "plt.show()" + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "code", + "source": "", + "id": "fae657ba6d3734a4", + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}