Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
update model
Browse files Browse the repository at this point in the history
  • Loading branch information
josephmje committed Dec 3, 2021
1 parent 46ee4f5 commit 1ac4ebf
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 210 deletions.
341 changes: 154 additions & 187 deletions docs/notebooks/Testing GP model.ipynb

Large diffs are not rendered by default.

45 changes: 22 additions & 23 deletions eddymotion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,28 +334,6 @@ def predict(self, gradient, **kwargs):
return retval


def _rasb2dipy(gradient):
gradient = np.asanyarray(gradient)
if gradient.ndim == 1:
if gradient.size != 4:
raise ValueError("Missing gradient information.")
gradient = gradient[..., np.newaxis]

if gradient.shape[0] != 4:
gradient = gradient.T
elif gradient.shape == (4, 4):
print("Warning: make sure gradient information is not transposed!")

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
retval = gradient_table(gradient[3, :], gradient[:3, :].T)
return retval


def _model_fit(model, data):
return model.fit(data)


class SparseFascicleModel:
"""
A wrapper of :obj:`dipy.reconst.sfm.SparseFascicleModel.
Expand All @@ -366,7 +344,6 @@ class SparseFascicleModel:
def __init__(self, gtab, S0=None, mask=None, solver=None, **kwargs):
"""Instantiate the wrapped model."""
from dipy.reconst.sfm import SparseFascicleModel
from sklearn.gaussian_process import GaussianProcessRegressor

self._S0 = None
if S0 is not None:
Expand Down Expand Up @@ -408,3 +385,25 @@ def predict(self, gradient, **kwargs):
retval = np.zeros_like(self._mask, dtype="float32")
retval[self._mask, ...] = predicted
return retval


def _rasb2dipy(gradient):
gradient = np.asanyarray(gradient)
if gradient.ndim == 1:
if gradient.size != 4:
raise ValueError("Missing gradient information.")
gradient = gradient[..., np.newaxis]

if gradient.shape[0] != 4:
gradient = gradient.T
elif gradient.shape == (4, 4):
print("Warning: make sure gradient information is not transposed!")

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning)
retval = gradient_table(gradient[3, :], gradient[:3, :].T)
return retval


def _model_fit(model, data):
return model.fit(data)

0 comments on commit 1ac4ebf

Please sign in to comment.