Skip to content

Commit

Permalink
Merge pull request nipreps#60 from josephmje/enh/test_gp_model
Browse files Browse the repository at this point in the history
ENH: Adds Sparse Fascicle and Gaussian Process models
  • Loading branch information
dPys authored Dec 10, 2021
2 parents 503a021 + b232df6 commit 2c122a4
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 11 deletions.
76 changes: 66 additions & 10 deletions eddymotion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import nest_asyncio

import numpy as np
from dipy.core.gradients import gradient_table
from dipy.core.gradients import check_multi_b, gradient_table

nest_asyncio.apply()

Expand All @@ -25,7 +25,7 @@ def init(gtab, model="DTI", **kwargs):
An array representing the gradient table in RAS+B format.
model : :obj:`str`
Diffusion model.
Options: ``"3DShore"``, ``"SFM"``, ``"DTI"``, ``"DKI"``, ``"S0"``
Options: ``"3DShore"``, ``"SFM"``, ``"GP"``, ``"DTI"``, ``"DKI"``, ``"S0"``
Return
------
Expand Down Expand Up @@ -53,15 +53,18 @@ def init(gtab, model="DTI", **kwargs):
"lambdaL": 1e-8,
}

elif model.lower().startswith("sfm"):
from eddymotion.utils.model import (
SFM4HMC as Model,
ExponentialIsotropicModel,
)
elif model.lower() in ("sfm", "gp"):
Model = SparseFascicleModel
param = {"solver": "ElasticNet"}

param = {
"isotropic": ExponentialIsotropicModel,
}
if model.lower() == "gp":
from sklearn.gaussian_process import GaussianProcessRegressor
param = {"solver": GaussianProcessRegressor}

multi_b = check_multi_b(gtab, 2, non_zero=False)
if multi_b:
from dipy.reconst.sfm import ExponentialIsotropicModel
param.update({"isotropic": ExponentialIsotropicModel})

elif model.lower() in ("dti", "dki"):
Model = DTIModel if model.lower() == "dti" else DKIModel
Expand Down Expand Up @@ -332,6 +335,59 @@ def predict(self, gradient, **kwargs):
return retval


class SparseFascicleModel:
"""
A wrapper of :obj:`dipy.reconst.sfm.SparseFascicleModel.
"""

__slots__ = ("_model", "_S0", "_mask", "_solver")

def __init__(self, gtab, S0=None, mask=None, solver=None, **kwargs):
"""Instantiate the wrapped model."""
from dipy.reconst.sfm import SparseFascicleModel

self._S0 = None
if S0 is not None:
self._S0 = np.clip(
S0.astype("float32") / S0.max(),
a_min=1e-5,
a_max=1.0,
)

self._mask = mask
if mask is None and S0 is not None:
self._mask = self._S0 > np.percentile(self._S0, 35)

if self._mask is not None:
self._S0 = self._S0[self._mask.astype(bool)]

self._solver = solver
if solver is None:
self._solver = "ElasticNet"

kwargs = {k: v for k, v in kwargs.items() if k in ("solver",)}
self._model = SparseFascicleModel(gtab, **kwargs)

def fit(self, data, **kwargs):
"""Clean-up permitted args and kwargs, and call model's fit."""
self._model = self._model.fit(data[self._mask, ...])

def predict(self, gradient, **kwargs):
"""Propagate model parameters and call predict."""
predicted = np.squeeze(
self._model.predict(
_rasb2dipy(gradient),
S0=self._S0,
)
)
if predicted.ndim == 3:
return predicted

retval = np.zeros_like(self._mask, dtype="float32")
retval[self._mask, ...] = predicted
return retval


def _rasb2dipy(gradient):
gradient = np.asanyarray(gradient)
if gradient.ndim == 1:
Expand Down
5 changes: 4 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ url = https://github.com/nipreps/EddyMotionCorrection
python_requires = >=3.7
install_requires =
dipy>=1.3.0
scikit-image>=0.14.2
nipype>= 1.5.1, < 2.0
nitransforms>=21.0.0
nest-asyncio>=1.5.1
scikit-image>=0.14.2
scikit-learn>=1.0.1
test_requires =
codecov
coverage
Expand Down

0 comments on commit 2c122a4

Please sign in to comment.