Skip to content
This repository has been archived by the owner on Dec 20, 2024. It is now read-only.

Commit

Permalink
Merge pull request #221 from nipreps/fix/generalize-models
Browse files Browse the repository at this point in the history
FIX: Generalized model structure
  • Loading branch information
oesteban authored Aug 29, 2024
2 parents 2a90ca6 + f2576d3 commit fcd9e94
Show file tree
Hide file tree
Showing 11 changed files with 1,204 additions and 1,141 deletions.
1 change: 1 addition & 0 deletions docs/developers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ Information on specific functions, classes, and methods.
api/eddymotion.exceptions
api/eddymotion.math
api/eddymotion.model
api/eddymotion.registration
api/eddymotion.utils
12 changes: 8 additions & 4 deletions src/eddymotion/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,25 @@
"""Data models."""

from eddymotion.model.base import (
AverageModel,
ModelFactory,
TrivialModel,
)
from eddymotion.model.dmri import (
AverageDWModel,
DKIModel,
DTIModel,
GPModel,
ModelFactory,
PETModel,
TrivialB0Model,
)
from eddymotion.model.pet import PETModel

__all__ = (
"ModelFactory",
"AverageModel",
"AverageDWModel",
"DKIModel",
"DTIModel",
"GPModel",
"TrivialB0Model",
"TrivialModel",
"PETModel",
)
21 changes: 20 additions & 1 deletion src/eddymotion/model/dipy.py → src/eddymotion/model/_dipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@

from __future__ import annotations

import warnings
from sys import modules

import numpy as np
from dipy.core.gradients import GradientTable
from dipy.core.gradients import GradientTable, gradient_table
from dipy.reconst.base import ReconstModel
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import (
Expand Down Expand Up @@ -690,3 +691,21 @@ def set_params(self, **params):
self.a = params.get("a", self.a)
self.sigma_sq = params.get("sigma_sq", self.sigma_sq)
return self


def _rasb2dipy(gradient):
    """
    Convert a RAS+b gradient table to a DIPY ``GradientTable`` object.

    Parameters
    ----------
    gradient : array-like
        Gradient information in RAS+b layout, shaped ``(4, N)`` — three
        rows of gradient components followed by one row of b-values.
        An ``(N, 4)`` array is transposed automatically, and a 1-D array
        of exactly four elements encodes a single gradient.

    Returns
    -------
    :obj:`dipy.core.gradients.GradientTable`
        The corresponding DIPY gradient table.

    Raises
    ------
    ValueError
        If a 1-D input does not contain exactly four elements.

    """
    gradient = np.asanyarray(gradient)
    if gradient.ndim == 1:
        if gradient.size != 4:
            raise ValueError("Missing gradient information.")
        # Promote a single gradient to a (4, 1) column so the slicing below works
        gradient = gradient[..., np.newaxis]

    if gradient.shape[0] != 4:
        gradient = gradient.T
    elif gradient.shape == (4, 4):
        # A 4x4 array is ambiguous: it could be (4, N) or (N, 4); we cannot
        # tell which, so alert the caller through the warnings machinery
        # (was a bare print to stdout, invisible to logging/warning filters).
        warnings.warn("Make sure gradient information is not transposed!", stacklevel=2)

    with warnings.catch_warnings():
        # DIPY emits a UserWarning for degenerate/b0 vectors; silence it here
        warnings.filterwarnings("ignore", category=UserWarning)
        retval = gradient_table(gradient[3, :], gradient[:3, :].T)
    return retval
Loading

0 comments on commit fcd9e94

Please sign in to comment.