Skip to content

Commit

Permalink
ENH: Implement Gaussian Process
Browse files Browse the repository at this point in the history
Implement Gaussian Process.
  • Loading branch information
jhlegarreta committed May 16, 2024
1 parent be66f4d commit 2234d06
Show file tree
Hide file tree
Showing 8 changed files with 801 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"numpy>=1.17.3",
"nest-asyncio>=1.5.1",
"scikit-image>=0.14.2",
"scikit_learn",
"scipy>=1.8.0",
]
dynamic = ["version"]
Expand Down
1 change: 1 addition & 0 deletions src/eddymotion/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def estimate(
"avg",
"average",
"mean",
"gp",
) or model.lower().startswith("full")

dwmodel = None
Expand Down
2 changes: 2 additions & 0 deletions src/eddymotion/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
AverageDWModel,
DKIModel,
DTIModel,
GaussianProcessModel,
ModelFactory,
PETModel,
TrivialB0Model,
Expand All @@ -36,6 +37,7 @@
"AverageDWModel",
"DKIModel",
"DTIModel",
"GaussianProcessModel",
"TrivialB0Model",
"PETModel",
)
105 changes: 105 additions & 0 deletions src/eddymotion/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import warnings

import numpy as np
import sklearn
from dipy.core.gradients import gradient_table
from joblib import Parallel, delayed

Expand Down Expand Up @@ -127,6 +128,8 @@ def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs):
if not model_str:
raise TypeError("No model defined")

# ToDo
# Use lazy loading ?
from importlib import import_module

module_name, class_name = model_str.rsplit(".", 1)
Expand Down Expand Up @@ -409,6 +412,108 @@ class DKIModel(BaseModel):
_model_class = "dipy.reconst.dki.DiffusionKurtosisModel"


class GaussianProcessModel:
"""A Gaussian Process model based on [Andersson16a]_ (fig 1).
DWIs need to be transformed to a single ref space (fig 2 [Andersson16b]_ ?)
Definitions:
s: reference/undistorted space: used to denote the space or any image in
that space
f: observed/distorted image: used to denote any image in acquisition space
a: acquisition parameters: PE‐direction and bandwidth in PE‐direction
r: rigid body (subject movement) parameters
\beta: Eddy current parameters
e(\beta): Eddy current‐induced off resonance field (Hz)
h: Susceptibility induced off‐resonance field (Hz)
(fig 1) and algorithm:
1. Input: N DWI volumes f_{i} with acq parameters a_{i}; susceptibility
field h
2. Initialize: set all beta_{i} and r_i{i} = 0
3. Compute for M iterations
- Load GP prediction maker
- For all i in N (DWIs)
- Compute \\hat{s}_{i} (f_{i}, h, \beta_{i}, r_{i}, a_{i}) eqs 2 and 4
- Load \\hat{s}_{i} (f_{i}, h, \beta_{i}, r_{i}, a_{i}) as training
data for GP
- Estimate hyperparameters for the GP used to predict the signal shape
for every voxel
- Update EC and movement parameters
- For all i in N (DWIs)
- Draw a prediction s_{i} from the GP
- Compute \\hat{f}_{i} (s_{i}, h, \beta_{i}, r_{i}, a_{i})
- Use \\hat{f}_{i} - f_{i} to update \beta_{i} and r_{i} (eq 6)
a: direction of the PE and the total readout time (here defined as the time
between the acquisition of the center of the first and last echoes).
Internally, a is divided into a = [p t] where p is a unity length 1 x 3
vector defining the PE direction (such that for example [1 0 0], [−1 0 0],
[0 1 0] and [0 −1 0] denote R → L, L → R, P → A and A P PE direction
respectively) and where t denotes the readout time (in seconds).
r: 1x6 vector: 3 translations, 3 rotations
\beta: four for linear; ten for quadratic, twenty for cubic.
h: assumed to be in the same space as the first b = 0 image supplied to
eddy, which will be automatically fulfilled if it was estimated by topup
and that same b = 0 image was the first of those supplied to topup. Hence,
it can be said to help define the reference/undistorted space as the first
b = 0 image after distortion correction by h.
See Appendix A for further details.
Add the outlier detection part in [Andersson16b]?
References
----------
.. [Andersson16a] J. L. R. Andersson. et al., An integrated approach to
correction for off-resonance effects and subject movement in diffusion MR
imaging, NeuroImage 125 (2016) 1063–1078
.. [Andersson16b] J. L. R. Andersson. et al., Incorporating outlier
detection and replacement into a non-parametric framework for movement and
distortion correction of diffusion MR images, NeuroImage 141 (2016) 556–572
"""

__slots__ = ("_dwi", "_a", "_h", "_kernel", "_num_iterations", "_betas", "_r", "_gpr")

def __init__(self, dwi, a, h, kernel, num_iterations=5, **kwargs):
"""Implement object initialization."""

# ToDo
# This should be he HDF5 file dwi object, so that we avoid having the
# entire 4D volume in memory
self._dwi = dwi
self._a = a
self._h = h
self._num_iterations = num_iterations

# Initialize
self._betas = 0
self._r = 0

# ToDo
# Build the GP kernel here or in fit ?
# self._gpr = None
# Does the kernel depend on which data we use as the training data (i.e.
# varies with the index we choose to predict)?
self._kernel = kernel

def fit(self, *args, **kwargs):
"""The x are our gradient directions; the observations are our diffusion
volumes.
X_train: array-like of shape (n_samples, n_features), n_samples being
the number of gradients, and the n_features the number of shells ?
y_train_array-like of shape (n_samples,) or (n_samples, n_targets)"""

self._gpr = sklearn.gaussian_process.GaussianProcessRegressor(kernel=self._kernel)
self._gpr.fit(X_train, y_train)

def predict(self, gradient, **kwargs):
"""Return the Gaussian Process prediction according to [Andersson16]_"""
# ToDo
# Call self._gprlog_marginal_likelihood for eq. 12 in Andersson 15 ?
y_mean, y_std = self._gpr.predict(X, return_std=True)
return y_mean, y_std


def _rasb2dipy(gradient):
gradient = np.asanyarray(gradient)
if gradient.ndim == 1:
Expand Down
Empty file.
170 changes: 170 additions & 0 deletions src/eddymotion/model/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2022 The NiPreps Developers <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
import nibabel as nib
import numpy as np
from eddymotion.model.utils import (
extract_dmri_shell,
find_shelling_scheme,
is_positive_definite,
# update_covariance1,
# update_covariance2,
)


def test_is_positive_definite():

matrix = np.array([[4, 1, 2], [1, 3, 1], [2, 1, 5]])
assert is_positive_definite(matrix)

matrix = np.array([[4, 1, 2], [1, -3, 1], [2, 1, 5]])
assert not is_positive_definite(matrix)


def test_update_covariance():

_K = np.random.rand(5, 5)
_thpar = [0.5, 1.0, 2.0]
update_covariance1(_K, _thpar)
print(_K) # Updated covariance matrix


def test_extract_dmri_shell():

# dMRI volume with 5 gradients
bvals = np.asarray([0, 1980, 12, 990, 2000])
bval_count = len(bvals)
vols_size = (10, 15, 20)
dwi = np.ones((*vols_size, bval_count))
bvecs = np.ones((bval_count, 3))
# Set all i-th gradient dMRI volume data and bvecs values to i
for i in range(bval_count):
dwi[..., i] = i
bvecs[i, :] = i
dwi_img = nib.Nifti1Image(dwi, affine=np.eye(4))

bvals_to_extract = [0, 2000]
tol = 15

expected_indices = np.asarray([0, 2, 4])
expected_shell_data = np.stack([i*np.ones(vols_size) for i in expected_indices], axis=-1)
expected_shell_bvals = np.asarray([0, 12, 2000])
expected_shell_bvecs = np.asarray([[i]*3 for i in expected_indices])

(
obtained_indices,
obtained_shell_data,
obtained_shell_bvals,
obtained_shell_bvecs
) = extract_dmri_shell(
dwi_img, bvals, bvecs, bvals_to_extract=bvals_to_extract, tol=tol)

assert np.array_equal(obtained_indices, expected_indices)
assert np.array_equal(obtained_shell_data, expected_shell_data)
assert np.array_equal(obtained_shell_bvals, expected_shell_bvals)
assert np.array_equal(obtained_shell_bvecs, expected_shell_bvecs)

bvals = np.asarray([0, 1010, 12, 990, 2000])
bval_count = len(bvals)
vols_size = (10, 15, 20)
dwi = np.ones((*vols_size, bval_count))
bvecs = np.ones((bval_count, 3))
# Set all i-th gradient dMRI volume data and bvecs values to i
for i in range(bval_count):
dwi[..., i] = i
bvecs[i, :] = i
dwi_img = nib.Nifti1Image(dwi, affine=np.eye(4))

bvals_to_extract = [0, 1000]
tol = 20

expected_indices = np.asarray([0, 1, 2, 3])
expected_shell_data = np.stack([i*np.ones(vols_size) for i in expected_indices], axis=-1)
expected_shell_bvals = np.asarray([0, 1010, 12, 990])
expected_shell_bvecs = np.asarray([[i]*3 for i in expected_indices])

(
obtained_indices,
obtained_shell_data,
obtained_shell_bvals,
obtained_shell_bvecs
) = extract_dmri_shell(
dwi_img, bvals, bvecs, bvals_to_extract=bvals_to_extract, tol=tol)

assert np.array_equal(obtained_indices, expected_indices)
assert np.array_equal(obtained_shell_data, expected_shell_data)
assert np.array_equal(obtained_shell_bvals, expected_shell_bvals)
assert np.array_equal(obtained_shell_bvecs, expected_shell_bvecs)


def test_find_shelling_scheme():

tol = 20
bvals = np.asarray([0, 0])
expected_shells = np.asarray([0])
expected_bval_centroids = np.asarray([0, 0])
obtained_shells, obtained_bval_centroids = find_shelling_scheme(
bvals, tol=tol)

assert np.array_equal(obtained_shells, expected_shells)
assert np.array_equal(obtained_bval_centroids, expected_bval_centroids)

bvals = np.asarray([
5, 300, 300, 300, 300, 300, 305, 1005, 995, 1000, 1000, 1005, 1000,
1000, 1005, 995, 1000, 1005, 5, 995, 1000, 1000, 995, 1005, 995, 1000,
995, 995, 2005, 2000, 2005, 2005, 1995, 2000, 2005, 2000, 1995, 2005, 5,
1995, 2005, 1995, 1995, 2005, 2005, 1995, 2000, 2000, 2000, 1995, 2000, 2000,
2005, 2005, 1995, 2005, 2005, 1990, 1995, 1995, 1995, 2005, 2000, 1990, 2010, 5
])
expected_shells = np.asarray([5., 300.83333333, 999.5, 2000.])
expected_bval_centroids = ([
5., 300.83333333, 300.83333333, 300.83333333, 300.83333333, 300.83333333, 300.83333333, 999.5, 999.5, 999.5, 999.5, 999.5, 999.5,
999.5, 999.5, 999.5, 999.5, 999.5, 5., 999.5, 999.5, 999.5, 999.5, 999.5, 999.5, 999.5,
999.5, 999.5, 2000., 2000., 2000., 2000., 2000., 2000., 2000., 2000., 2000., 2000., 5.,
2000., 2000., 2000., 2000., 2000., 2000., 2000., 2000., 2000., 2000., 2000., 2000., 2000.,
2000., 2000., 2000., 2000., 2000., 2000., 2000., 2000., 2000., 2000., 2000., 2000., 2000., 5.
])
obtained_shells, obtained_bval_centroids = find_shelling_scheme(
bvals, tol=tol)

# ToDo
# Giving a tolerance of 15 this fails because it finds 5 clusters
assert np.allclose(obtained_shells, expected_shells)
assert np.allclose(obtained_bval_centroids, expected_bval_centroids)

bvals = np.asarray([0, 1980, 12, 990, 2000])
expected_shells = np.asarray([6, 990, 1980, 2000])
expected_bval_centroids = np.asarray([6, 1980, 6, 990, 2000])
obtained_shells, obtained_bval_centroids = find_shelling_scheme(
bvals, tol=tol)

assert np.allclose(obtained_shells, expected_shells)
assert np.allclose(obtained_bval_centroids, expected_bval_centroids)

bvals = np.asarray([0, 1010, 12, 990, 2000])
tol = 60
expected_shells = np.asarray([6, 1000, 2000])
expected_bval_centroids = np.asarray([6, 1000, 6, 1000, 2000])
obtained_shells, obtained_bval_centroids = find_shelling_scheme(bvals, tol)

assert np.allclose(obtained_shells, expected_shells)
assert np.allclose(obtained_bval_centroids, expected_bval_centroids)
Loading

0 comments on commit 2234d06

Please sign in to comment.