Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hyperparameters optimization #2

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions docs/notebooks/dwi_gp_representation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "Gaussian process notebook",
"id": "486923b289155658"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-05-27T22:08:12.924224Z",
"start_time": "2024-05-27T22:07:46.892722Z"
}
},
"cell_type": "code",
"source": [
"import tempfile\n",
"from pathlib import Path\n",
"\n",
"import numpy as np\n",
"from sklearn.gaussian_process.kernels import DotProduct, WhiteKernel\n",
"\n",
"from eddymotion import model\n",
"from eddymotion.data.dmri import DWI\n",
"from eddymotion.data.splitting import lovo_split\n",
"\n",
"datadir = Path(\"../../test\") # Adapt to your local path or download to a temp location using wget\n",
"\n",
"kernel = DotProduct() + WhiteKernel()\n",
"\n",
"dwi = DWI.from_filename(datadir / \"dwi.h5\")\n",
"\n",
"_dwi_data = dwi.dataobj\n",
"# Use a subset of the data for now to see that something is written to the\n",
"# output\n",
"# bvecs = dwi.gradients[:3, :].T\n",
"bvecs = dwi.gradients[:3, 10:13].T # b0 values have already been masked\n",
"# bvals = dwi.gradients[3:, 10:13].T # Only for inspection purposes: [[1005.], [1000.], [ 995.]]\n",
"dwi_data = _dwi_data[60:63, 60:64, 40:45, 10:13]\n",
"\n",
"# ToDo\n",
"# Provide proper values/estimates for these\n",
"a = 1\n",
"h = 1 # should be a NIfTI image\n",
"\n",
"num_iterations = 5\n",
"gp = model.GaussianProcessModel(\n",
" dwi=dwi, a=a, h=h, kernel=kernel, num_iterations=num_iterations\n",
")\n",
"indices = list(range(bvecs.shape[0]))\n",
"# ToDo\n",
"# This should be done within the GP model class\n",
"# Apply lovo strategy properly\n",
"# Vectorize and parallelize\n",
"result = np.zeros_like(dwi_data)\n",
"for idx in indices:\n",
" lovo_idx = np.ones(len(indices), dtype=bool)\n",
" lovo_idx[idx] = False\n",
" X = bvecs[lovo_idx]\n",
" for i in range(dwi_data.shape[0]):\n",
" for j in range(dwi_data.shape[1]):\n",
" for k in range(dwi_data.shape[2]):\n",
" # ToDo\n",
" # Use a mask to avoid traversing background data\n",
" y = dwi_data[i, j, k, lovo_idx]\n",
" gp.fit(X, y)\n",
" prediction, _ = gp.predict(\n",
" bvecs[idx, :][np.newaxis]\n",
" ) # Can take multiple values X[:2, :]\n",
" result[i, j, k, idx] = prediction.item()"
],
"id": "da2274009534db61",
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
"Cell \u001B[0;32mIn[3], line 21\u001B[0m\n\u001B[1;32m 17\u001B[0m _dwi_data \u001B[38;5;241m=\u001B[39m dwi\u001B[38;5;241m.\u001B[39mdataobj\n\u001B[1;32m 18\u001B[0m \u001B[38;5;66;03m# Use a subset of the data for now to see that something is written to the\u001B[39;00m\n\u001B[1;32m 19\u001B[0m \u001B[38;5;66;03m# output\u001B[39;00m\n\u001B[1;32m 20\u001B[0m \u001B[38;5;66;03m# bvecs = dwi.gradients[:3, :].T\u001B[39;00m\n\u001B[0;32m---> 21\u001B[0m bvecs \u001B[38;5;241m=\u001B[39m \u001B[43mdwi\u001B[49m\u001B[38;5;241m.\u001B[39mgradients[:\u001B[38;5;241m3\u001B[39m, \u001B[38;5;241m10\u001B[39m:\u001B[38;5;241m13\u001B[39m]\u001B[38;5;241m.\u001B[39mT \u001B[38;5;66;03m# b0 values have already been masked\u001B[39;00m\n\u001B[1;32m 22\u001B[0m \u001B[38;5;66;03m# bvals = dwi.gradients[3:, 10:13].T # Only for inspection purposes: [[1005.], [1000.], [ 995.]]\u001B[39;00m\n\u001B[1;32m 23\u001B[0m dwi_data \u001B[38;5;241m=\u001B[39m _dwi_data[\u001B[38;5;241m60\u001B[39m:\u001B[38;5;241m63\u001B[39m, \u001B[38;5;241m60\u001B[39m:\u001B[38;5;241m64\u001B[39m, \u001B[38;5;241m40\u001B[39m:\u001B[38;5;241m45\u001B[39m, \u001B[38;5;241m10\u001B[39m:\u001B[38;5;241m13\u001B[39m]\n",
"Cell \u001B[0;32mIn[3], line 21\u001B[0m\n\u001B[1;32m 17\u001B[0m _dwi_data \u001B[38;5;241m=\u001B[39m dwi\u001B[38;5;241m.\u001B[39mdataobj\n\u001B[1;32m 18\u001B[0m \u001B[38;5;66;03m# Use a subset of the data for now to see that something is written to the\u001B[39;00m\n\u001B[1;32m 19\u001B[0m \u001B[38;5;66;03m# output\u001B[39;00m\n\u001B[1;32m 20\u001B[0m \u001B[38;5;66;03m# bvecs = dwi.gradients[:3, :].T\u001B[39;00m\n\u001B[0;32m---> 21\u001B[0m bvecs \u001B[38;5;241m=\u001B[39m \u001B[43mdwi\u001B[49m\u001B[38;5;241m.\u001B[39mgradients[:\u001B[38;5;241m3\u001B[39m, \u001B[38;5;241m10\u001B[39m:\u001B[38;5;241m13\u001B[39m]\u001B[38;5;241m.\u001B[39mT \u001B[38;5;66;03m# b0 values have already been masked\u001B[39;00m\n\u001B[1;32m 22\u001B[0m \u001B[38;5;66;03m# bvals = dwi.gradients[3:, 10:13].T # Only for inspection purposes: [[1005.], [1000.], [ 995.]]\u001B[39;00m\n\u001B[1;32m 23\u001B[0m dwi_data \u001B[38;5;241m=\u001B[39m _dwi_data[\u001B[38;5;241m60\u001B[39m:\u001B[38;5;241m63\u001B[39m, \u001B[38;5;241m60\u001B[39m:\u001B[38;5;241m64\u001B[39m, \u001B[38;5;241m40\u001B[39m:\u001B[38;5;241m45\u001B[39m, \u001B[38;5;241m10\u001B[39m:\u001B[38;5;241m13\u001B[39m]\n",
"File \u001B[0;32m/snap/pycharm-professional/387/plugins/python/helpers/pydev/_pydevd_bundle/pydevd_frame.py:888\u001B[0m, in \u001B[0;36mPyDBFrame.trace_dispatch\u001B[0;34m(self, frame, event, arg)\u001B[0m\n\u001B[1;32m 885\u001B[0m stop \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;01mFalse\u001B[39;00m\n\u001B[1;32m 887\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m plugin_stop:\n\u001B[0;32m--> 888\u001B[0m stopped_on_plugin \u001B[38;5;241m=\u001B[39m \u001B[43mplugin_manager\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mstop\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmain_debugger\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mframe\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mevent\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_args\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mstop_info\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43marg\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mstep_cmd\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 889\u001B[0m \u001B[38;5;28;01melif\u001B[39;00m stop:\n\u001B[1;32m 890\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m is_line:\n",
"File \u001B[0;32m/snap/pycharm-professional/387/plugins/python/helpers-pro/jupyter_debug/pydev_jupyter_plugin.py:169\u001B[0m, in \u001B[0;36mstop\u001B[0;34m(plugin, pydb, frame, event, args, stop_info, arg, step_cmd)\u001B[0m\n\u001B[1;32m 167\u001B[0m frame \u001B[38;5;241m=\u001B[39m suspend_jupyter(main_debugger, thread, frame, step_cmd)\n\u001B[1;32m 168\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m frame:\n\u001B[0;32m--> 169\u001B[0m \u001B[43mmain_debugger\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mdo_wait_suspend\u001B[49m\u001B[43m(\u001B[49m\u001B[43mthread\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mframe\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mevent\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43marg\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 170\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;01mTrue\u001B[39;00m\n\u001B[1;32m 171\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;01mFalse\u001B[39;00m\n",
"File \u001B[0;32m/snap/pycharm-professional/387/plugins/python/helpers/pydev/pydevd.py:1185\u001B[0m, in \u001B[0;36mPyDB.do_wait_suspend\u001B[0;34m(self, thread, frame, event, arg, send_suspend_message, is_unhandled_exception)\u001B[0m\n\u001B[1;32m 1182\u001B[0m from_this_thread\u001B[38;5;241m.\u001B[39mappend(frame_id)\n\u001B[1;32m 1184\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_threads_suspended_single_notification\u001B[38;5;241m.\u001B[39mnotify_thread_suspended(thread_id, stop_reason):\n\u001B[0;32m-> 1185\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_do_wait_suspend\u001B[49m\u001B[43m(\u001B[49m\u001B[43mthread\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mframe\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mevent\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43marg\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msuspend_type\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mfrom_this_thread\u001B[49m\u001B[43m)\u001B[49m\n",
"File \u001B[0;32m/snap/pycharm-professional/387/plugins/python/helpers/pydev/pydevd.py:1200\u001B[0m, in \u001B[0;36mPyDB._do_wait_suspend\u001B[0;34m(self, thread, frame, event, arg, suspend_type, from_this_thread)\u001B[0m\n\u001B[1;32m 1197\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_call_mpl_hook()\n\u001B[1;32m 1199\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mprocess_internal_commands()\n\u001B[0;32m-> 1200\u001B[0m \u001B[43mtime\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msleep\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m0.01\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m 1202\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcancel_async_evaluation(get_current_thread_id(thread), \u001B[38;5;28mstr\u001B[39m(\u001B[38;5;28mid\u001B[39m(frame)))\n\u001B[1;32m 1204\u001B[0m \u001B[38;5;66;03m# process any stepping instructions\u001B[39;00m\n",
"\u001B[0;31mKeyboardInterrupt\u001B[0m: "
]
}
],
"execution_count": 3
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Plot the data",
"id": "77e77cd4c73409d3"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"from matplotlib import pyplot as plt \n",
"%matplotlib inline\n",
"\n",
"s = dwi_data[1, 1, 2, :]\n",
"s_hat = result[1, 1, 2, :]\n",
"x = indices\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.plot(x, np.c_[s, s_hat], label=[\"ground truth\", \"predicted\"])\n",
"ax.set_xlabel(\"bvec indices\")\n",
"ax.set_ylabel(\"signal\")\n",
"ax.legend()\n",
"\n",
"plt.show()"
],
"id": "4e51f22890fb045a"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Plot the DWI signal brain data",
"id": "46bb0987d58eb394"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "# plot_dwi(dmri_dataset.dataobj, dmri_dataset.affine, gradient=data_test[1], black_bg=True)",
"id": "edb0e9d255516e38"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Plot the predicted DWI signal",
"id": "1a52e2450fc61dc6"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "# plot_dwi(predicted, dmri_dataset.affine, gradient=data_test[1], black_bg=True);",
"id": "66150cf337b395e0"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"numpy>=1.17.3",
"nest-asyncio>=1.5.1",
"scikit-image>=0.14.2",
"scikit_learn",
"scipy>=1.8.0",
]
dynamic = ["version"]
Expand Down
1 change: 1 addition & 0 deletions src/eddymotion/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def estimate(
"avg",
"average",
"mean",
"gp",
) or model.lower().startswith("full")

dwmodel = None
Expand Down
2 changes: 2 additions & 0 deletions src/eddymotion/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
AverageDWModel,
DKIModel,
DTIModel,
GaussianProcessModel,
ModelFactory,
PETModel,
TrivialB0Model,
Expand All @@ -36,6 +37,7 @@
"AverageDWModel",
"DKIModel",
"DTIModel",
"GaussianProcessModel",
"TrivialB0Model",
"PETModel",
)
156 changes: 156 additions & 0 deletions src/eddymotion/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
import numpy as np
from dipy.core.gradients import gradient_table
from joblib import Parallel, delayed
from sklearn.gaussian_process import GaussianProcessRegressor
from scipy.optimize import minimize, Bounds

from .utils import calculate_angle_matrix, stochastic_optimization_with_early_stopping,loo_cross_validation
from .kernels import SphericalCovarianceKernel


def _exec_fit(model, data, chunk=None):
Expand Down Expand Up @@ -127,6 +132,8 @@ def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs):
if not model_str:
raise TypeError("No model defined")

# ToDo
# Use lazy loading ?
from importlib import import_module

module_name, class_name = model_str.rsplit(".", 1)
Expand Down Expand Up @@ -239,6 +246,7 @@ def __init__(self, **kwargs):
r"""
Implement object initialization.

import sklearn
Parameters
----------
gtab : :obj:`~numpy.ndarray`
Expand Down Expand Up @@ -409,6 +417,154 @@ class DKIModel(BaseModel):
_model_class = "dipy.reconst.dki.DiffusionKurtosisModel"


class GaussianProcessModel:
"""A Gaussian Process model based on [Andersson16a]_ (fig 1).
DWIs need to be transformed to a single ref space (fig 2 [Andersson16b]_ ?)

Definitions:
s: reference/undistorted space: used to denote the space or any image in
that space
f: observed/distorted image: used to denote any image in acquisition space
a: acquisition parameters: PE‐direction and bandwidth in PE‐direction
r: rigid body (subject movement) parameters
\beta: Eddy current parameters
e(\beta): Eddy current‐induced off resonance field (Hz)
h: Susceptibility induced off‐resonance field (Hz)

(fig 1) and algorithm:
1. Input: N DWI volumes f_{i} with acq parameters a_{i}; susceptibility
field h
2. Initialize: set all beta_{i} and r_i{i} = 0
3. Compute for M iterations
- Load GP prediction maker
- For all i in N (DWIs)
- Compute \\hat{s}_{i} (f_{i}, h, \beta_{i}, r_{i}, a_{i}) eqs 2 and 4
- Load \\hat{s}_{i} (f_{i}, h, \beta_{i}, r_{i}, a_{i}) as training
data for GP
- Estimate hyperparameters for the GP used to predict the signal shape
for every voxel
- Update EC and movement parameters
- For all i in N (DWIs)
- Draw a prediction s_{i} from the GP
- Compute \\hat{f}_{i} (s_{i}, h, \beta_{i}, r_{i}, a_{i})
- Use \\hat{f}_{i} - f_{i} to update \beta_{i} and r_{i} (eq 6)

a: direction of the PE and the total readout time (here defined as the time
between the acquisition of the center of the first and last echoes).
Internally, a is divided into a = [p t] where p is a unity length 1 x 3
vector defining the PE direction (such that for example [1 0 0], [−1 0 0],
[0 1 0] and [0 −1 0] denote R → L, L → R, P → A and A P PE direction
respectively) and where t denotes the readout time (in seconds).
r: 1x6 vector: 3 translations, 3 rotations
\beta: four for linear; ten for quadratic, twenty for cubic.
h: assumed to be in the same space as the first b = 0 image supplied to
eddy, which will be automatically fulfilled if it was estimated by topup
and that same b = 0 image was the first of those supplied to topup. Hence,
it can be said to help define the reference/undistorted space as the first
b = 0 image after distortion correction by h.

See Appendix A for further details.

Add the outlier detection part in [Andersson16b]?

References
----------
.. [Andersson16a] J. L. R. Andersson. et al., An integrated approach to
correction for off-resonance effects and subject movement in diffusion MR
imaging, NeuroImage 125 (2016) 1063–1078
.. [Andersson16b] J. L. R. Andersson. et al., Incorporating outlier
detection and replacement into a non-parametric framework for movement and
distortion correction of diffusion MR images, NeuroImage 141 (2016) 556–572
"""

__slots__ = (
"_dwi",
"_a",
"_h",
"_kernel",
"_num_iterations",
"_betas",
"_r",
"_gpr",
"_model",
)

def __init__(self, dwi, a, h, kernel, num_iterations=5, **kwargs):
"""Implement object initialization."""
self._dwi = dwi
self._a = a
self._h = h
self._num_iterations = num_iterations
self._betas = 0
self._r = 0
self._kernel = kernel

def fit(self, gradient_directions, data, initial_beta, batch_size=1000, max_iter=1000, tolerance=1e-4, patience=20, hp_opti_method='LOO'):
"""
Fit the Gaussian Process model to the training data.

Parameters
----------
gradient_directions : array-like of shape (num_directions, 3)
Training data (gradient directions).
data : array-like of shape (num_directions, num_voxels)
Target values (diffusion volumes).
initial_beta : array-like of shape (3,)
Initial guess for the log-transformed hyperparameters.
batch_size : int, default=1000
Size of the mini-batches for optimization.
max_iter : int, default=1000
Maximum number of iterations.
tolerance : float, default=1e-4
Tolerance for improvement in loss.
patience : int, default=20
Patience for early stopping.
hp_opti_method : str, default='LOO'
Method for hyperparameter optimization. Options are 'LOO' for Leave-One-Out cross-validation
and 'MML' for Marginal Maximum Likelihood with stochastic optimization.
"""
# bounds estimation
lambda_upper = 1e4
lambda_lower = 1e-6
a_upper = np.pi
a_lower = 1e-6
sigma_sq_upper = 1e4
sigma_sq_lower = 1e-6

bounds = Bounds(np.log([lambda_lower, a_lower, sigma_sq_lower]), np.log([lambda_upper, a_upper, sigma_sq_upper]))

# Compute angles from the gradient directions
angles = calculate_angle_matrix(gradient_directions.T)

voxel_intensities_flatten = data.reshape(len(angles), -1)
reshaped_angles = angles.reshape(-1, 1)

if hp_opti_method == 'LOO':
result = minimize(loo_cross_validation, initial_beta, args=(voxel_intensities_flatten, angles), bounds=bounds, method='L-BFGS-B', options={'maxiter':10})
optimal_beta = result.x

elif hp_opti_method == 'MML':

optimal_beta = stochastic_optimization_with_early_stopping(initial_beta, voxel_intensities_flatten, angles, batch_size, bounds, max_iter, patience, tolerance)

else:
raise NotImplementedError(f"Hyperparameter optimization method {hp_opti_method} is not implemented")

optimal_lambda, optimal_a, optimal_sigma_sq = np.exp(optimal_beta)

# Update the kernel with the optimized hyperparameters
self._kernel.set_params(lambda_=optimal_lambda, a=optimal_a, sigma_sq=optimal_sigma_sq)

# Fit the Gaussian Process Regressor with the optimized kernel
self._gpr = GaussianProcessRegressor(kernel=self._kernel)
self._gpr.fit(angles, voxel_intensities_flatten)

def predict(self, angles, **kwargs):
"""Return the Gaussian Process prediction according to [Andersson16]_"""
y_mean, y_std = self._gpr.predict(angles, return_std=True)
return y_mean, y_std


def _rasb2dipy(gradient):
gradient = np.asanyarray(gradient)
if gradient.ndim == 1:
Expand Down
Loading