From 15e3a390cc84ab04a30e8392bd26128e6e3e6acb Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 29 Aug 2024 15:40:19 +0200 Subject: [PATCH 1/5] enh: generalize trivial and average models --- src/eddymotion/model/__init__.py | 4 +- src/eddymotion/model/base.py | 102 +++++++++++++++++++++++-------- 2 files changed, 79 insertions(+), 27 deletions(-) diff --git a/src/eddymotion/model/__init__.py b/src/eddymotion/model/__init__.py index ae454d40..103f9008 100644 --- a/src/eddymotion/model/__init__.py +++ b/src/eddymotion/model/__init__.py @@ -29,7 +29,7 @@ GPModel, ModelFactory, PETModel, - TrivialB0Model, + TrivialModel, ) __all__ = ( @@ -38,6 +38,6 @@ "DKIModel", "DTIModel", "GPModel", - "TrivialB0Model", + "TrivialModel", "PETModel", ) diff --git a/src/eddymotion/model/base.py b/src/eddymotion/model/base.py index 4da06061..fbaad99e 100644 --- a/src/eddymotion/model/base.py +++ b/src/eddymotion/model/base.py @@ -85,9 +85,12 @@ def init(model="DTI", **kwargs): if model.lower() in ("s0", "b0"): return TrivialB0Model(S0=kwargs.pop("S0"), gtab=kwargs.pop("gtab")) - if model.lower() in ("avg", "average", "mean"): + if model.lower() in ("avgdwi", "averagedwi", "meandwi"): return AverageDWModel(**kwargs) + if model.lower() in ("avg", "average", "mean"): + return AverageModel(**kwargs) + if model.lower() in ("dti", "dki", "pet"): Model = globals()[f"{model.upper()}Model"] return Model(**kwargs) @@ -247,6 +250,79 @@ def predict(self, index=None, **kwargs): retval = np.zeros(self._datashape, dtype="float32") retval[self._mask] = predicted return retval + +class TrivialModel(BaseModel): + """A trivial model that returns a given map always.""" + + __slots__ = ("_predicted", ) + + def __init__(self, predicted=None, **kwargs): + """Implement object initialization.""" + if predicted is None: + raise TypeError( + "This model requires the predicted map at initialization" + ) + + super().__init__(**kwargs) + self._predicted = predicted + self._datashape = predicted.shape + + @property + def is_fitted(self): + return True + + def fit(self, data, **kwargs): + """Do nothing.""" + + def predict(self, *_, **kwargs): + """Return the *b=0* map.""" + + # No need to check fit (if not fitted, has raised already) + return self._predicted + + +class AverageModel(BaseModel): + """A trivial model that returns an average map.""" + + __slots__ = ("_data", ) + + def __init__(self, **kwargs): + """Initialize a new model.""" + super().__init__(**kwargs) + self._data = None + + def fit(self, data, **kwargs): + """Calculate the average.""" + + # Regress out global signal differences + if kwargs.pop("equalize", False): + data = data.copy().astype('float32') + reshaped_data = ( + data.reshape((-1, data.shape[-1])) + if self._mask is None + else data[self._mask] + ) + p5 = np.percentile(reshaped_data, 5.0, axis=0) + p95 = np.percentile(reshaped_data, 95.0, axis=0) - p5 + data = (data - p5) * p95.mean() / p95 + p5.mean() + + # Select the summary statistic + avg_func = getattr(np, kwargs.pop("stat", "mean")) + + # Calculate the average + self._data = avg_func(data, axis=-1) + + @property + def is_fitted(self): + return self._data is not None + + def predict(self, *_, **kwargs): + """Return the average map.""" + + if self._data is None: + raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") + + return self._data class BaseDWIModel(BaseModel): @@ -400,30 +476,6 @@ def predict(self, gradient=None, **kwargs): return retval -class TrivialB0Model(BaseDWIModel): - """A trivial model that returns a *b=0* map always.""" - - def __init__(self, **kwargs): - """Implement object initialization.""" - super().__init__(**kwargs) - - if self._S0 is None: - raise ValueError("S0 must be provided") - - @property - def is_fitted(self): - return True - - def fit(self, data, **kwargs): - """Do nothing.""" - - def predict(self, *_, **kwargs): - """Return the *b=0* map.""" - - # No need to check fit (if not fitted, has raised already) - return self._S0 - - class AverageDWModel(BaseDWIModel): """A trivial model that returns an average map.""" From 77b8267756c69045ed7d0f68e4300dcd0aebf2c0 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 29 Aug 2024 16:02:17 +0200 Subject: [PATCH 2/5] fix: refactor package structure for better delineated models --- src/eddymotion/model/__init__.py | 10 +- src/eddymotion/model/{dipy.py => _dipy.py} | 21 +- src/eddymotion/model/base.py | 429 +-------------------- src/eddymotion/model/dmri.py | 378 ++++++++++++++++++ src/eddymotion/model/dmri_utils.py | 94 ----- src/eddymotion/model/pet.py | 138 +++++++ 6 files changed, 553 insertions(+), 517 deletions(-) rename src/eddymotion/model/{dipy.py => _dipy.py} (96%) create mode 100644 src/eddymotion/model/dmri.py delete mode 100644 src/eddymotion/model/dmri_utils.py create mode 100644 src/eddymotion/model/pet.py diff --git a/src/eddymotion/model/__init__.py b/src/eddymotion/model/__init__.py index 103f9008..149a71ec 100644 --- a/src/eddymotion/model/__init__.py +++ b/src/eddymotion/model/__init__.py @@ -23,17 +23,21 @@ """Data models.""" from eddymotion.model.base import ( + AverageModel, + ModelFactory, + TrivialModel, +) +from eddymotion.model.dmri import ( AverageDWModel, DKIModel, DTIModel, GPModel, - ModelFactory, - PETModel, - TrivialModel, ) +from eddymotion.model.pet import PETModel __all__ = ( "ModelFactory", + "AverageModel", "AverageDWModel", "DKIModel", "DTIModel", diff --git a/src/eddymotion/model/dipy.py b/src/eddymotion/model/_dipy.py similarity index 96% rename from src/eddymotion/model/dipy.py rename to src/eddymotion/model/_dipy.py index d69db579..7a97420e 100644 --- a/src/eddymotion/model/dipy.py +++ b/src/eddymotion/model/_dipy.py @@ -87,10 +87,11 @@ from __future__ import annotations +import warnings from sys import modules import numpy as np -from dipy.core.gradients import GradientTable +from dipy.core.gradients import GradientTable, gradient_table from dipy.reconst.base import ReconstModel from sklearn.gaussian_process import GaussianProcessRegressor from sklearn.gaussian_process.kernels import ( @@ -690,3 +691,21 @@ def set_params(self, **params): self.a = params.get("a", self.a) self.sigma_sq = params.get("sigma_sq", self.sigma_sq) return self + + +def _rasb2dipy(gradient): + gradient = np.asanyarray(gradient) + if gradient.ndim == 1: + if gradient.size != 4: + raise ValueError("Missing gradient information.") + gradient = gradient[..., np.newaxis] + + if gradient.shape[0] != 4: + gradient = gradient.T + elif gradient.shape == (4, 4): + print("Warning: make sure gradient information is not transposed!") + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + retval = gradient_table(gradient[3, :], gradient[:3, :].T) + return retval diff --git a/src/eddymotion/model/base.py b/src/eddymotion/model/base.py index fbaad99e..1bf8f01b 100644 --- a/src/eddymotion/model/base.py +++ b/src/eddymotion/model/base.py @@ -20,47 +20,12 @@ # # https://www.nipreps.org/community/licensing/ # -"""A factory class that adapts DIPY's dMRI models.""" - -import warnings +"""Base infrastructure for eddymotion's models.""" import numpy as np -from dipy.core.gradients import gradient_table -from joblib import Parallel, delayed from eddymotion.exceptions import ModelNotFittedError -DEFAULT_MIN_S0 = 1e-5 -"""Minimum value when considering the :math:`S_{0}` DWI signal.""" - -DEFAULT_MAX_S0 = 1.0 -"""Maximum value when considering the :math:`S_{0}` DWI signal.""" - -DEFAULT_MAX_BVALUE = 1000 -"""Maximum allowed value for the b-value.""" - -DEFAULT_LOWB_THRESHOLD = 50 -"""The lower bound for the b-value so that the orientation is considered a DW volume.""" - -DEFAULT_HIGHB_THRESHOLD = 10000 -"""A b-value cap for DWI data.""" - -DEFAULT_CLIP_PERCENTILE = 75 -"""Upper percentile threshold for intensity clipping.""" - -DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2 -"""Time frame tolerance in seconds.""" - - -def _exec_fit(model, data, chunk=None): - retval = model.fit(data) - return retval, chunk - - -def _exec_predict(model, chunk=None, **kwargs): - """Propagate model parameters and call predict.""" - return np.squeeze(model.predict(**kwargs)), chunk - class ModelFactory: """A factory for instantiating diffusion models.""" @@ -83,9 +48,11 @@ def init(model="DTI", **kwargs): """ if model.lower() in ("s0", "b0"): - return TrivialB0Model(S0=kwargs.pop("S0"), gtab=kwargs.pop("gtab")) + return TrivialModel(predicted=kwargs.pop("S0"), gtab=kwargs.pop("gtab")) if model.lower() in ("avgdwi", "averagedwi", "meandwi"): + from eddymotion.model.dmri import AverageDWModel + return AverageDWModel(**kwargs) if model.lower() in ("avg", "average", "mean"): @@ -146,123 +113,16 @@ def predict(self, *args, **kwargs): raise NotImplementedError("Cannot call predict() on a BaseModel instance.") -class PETModel(BaseModel): - """A PET imaging realignment model based on B-Spline approximation.""" - - __slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_n_ctrl") - - def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs): - """ - Create the B-Spline interpolating matrix. - - Parameters: - ----------- - timepoints : :obj:`list` - The timing (in sec) of each PET volume. - E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330., - 420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]`` - - n_ctrl : :obj:`int` - Number of B-Spline control points. If `None`, then one control point every - six timepoints will be used. The less control points, the smoother is the - model. - - """ - super.__init__(**kwargs) - - if timepoints is None or xlim is None: - raise TypeError("timepoints must be provided in initialization") - - self._order = order - - self._x = np.array(timepoints, dtype="float32") - self._xlim = xlim - - if self._x[0] < DEFAULT_TIMEFRAME_MIDPOINT_TOL: - raise ValueError("First frame midpoint should not be zero or negative") - if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL): - raise ValueError("Last frame midpoint should not be equal or greater than duration") - - # Calculate index coordinates in the B-Spline grid - self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1 - - # B-Spline knots - self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32") - - self._coeff = None - - @property - def is_fitted(self): - return self._coeff is not None - - def fit(self, data, **kwargs): - """Fit the model.""" - from scipy.interpolate import BSpline - from scipy.sparse.linalg import cg - - n_jobs = kwargs.pop("n_jobs", None) or 1 - - timepoints = kwargs.get("timepoints", None) or self._x - x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl - - self._datashape = data.shape[:3] - - # Convert data into V (voxels) x T (timepoints) - data = data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask] - - # A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding) - A = BSpline.design_matrix(x, self._t, k=self._order) - AT = A.T - ATdotA = AT @ A - - # One single CPU - linear execution (full model) - if n_jobs == 1: - self._coeff = np.array([cg(ATdotA, AT @ v)[0] for v in data]) - return - - # Parallelize process with joblib - with Parallel(n_jobs=n_jobs) as executor: - results = executor(delayed(cg)(ATdotA, AT @ v) for v in data) - - self._coeff = np.array([r[0] for r in results]) - - def predict(self, index=None, **kwargs): - """Return the corrected volume using B-spline interpolation.""" - from scipy.interpolate import BSpline - - if index is None: - raise ValueError("A timepoint index to be simulated must be provided.") - - if not self._is_fitted: - raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") - - # Project sample timing into B-Spline coordinates - x = (index / self._xlim) * self._n_ctrl - A = BSpline.design_matrix(x, self._t, k=self._order) - - # A is 1 (num. timepoints) x C (num. coeff) - # self._coeff is V (num. voxels) x K - 4 - predicted = np.squeeze(A @ self._coeff.T) - - if self._mask is None: - return predicted.reshape(self._datashape) - - retval = np.zeros(self._datashape, dtype="float32") - retval[self._mask] = predicted - return retval - class TrivialModel(BaseModel): """A trivial model that returns a given map always.""" - __slots__ = ("_predicted", ) + __slots__ = ("_predicted",) def __init__(self, predicted=None, **kwargs): """Implement object initialization.""" if predicted is None: - raise TypeError( - "This model requires the predicted map at initialization" - ) - + raise TypeError("This model requires the predicted map at initialization") + super().__init__(**kwargs) self._predicted = predicted self._datashape = predicted.shape @@ -284,7 +144,7 @@ def predict(self, *_, **kwargs): class AverageModel(BaseModel): """A trivial model that returns an average map.""" - __slots__ = ("_data", ) + __slots__ = ("_data",) def __init__(self, **kwargs): """Initialize a new model.""" @@ -296,11 +156,9 @@ def fit(self, data, **kwargs): # Regress out global signal differences if kwargs.pop("equalize", False): - data = data.copy().astype('float32') + data = data.copy().astype("float32") reshaped_data = ( - data.reshape((-1, data.shape[-1])) - if self._mask is None - else data[self._mask] + data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask] ) p5 = np.percentile(reshaped_data, 5.0, axis=0) p95 = np.percentile(reshaped_data, 95.0, axis=0) - p5 @@ -323,270 +181,3 @@ def predict(self, *_, **kwargs): raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") return self._data - - -class BaseDWIModel(BaseModel): - """Interface and default methods for DWI models.""" - - __slots__ = ( - "_gtab", - "_S0", - "_b_max", - "_model_class", # Defining a model class, DIPY models are instantiated automagically - "_modelargs", - ) - - def __init__(self, gtab, S0=None, b_max=None, **kwargs): - """Initialization. - - Parameters - ---------- - gtab : :obj:`numpy.ndarray` - An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and - columns are b-vector components and corresponding b-value, respectively. - S0 : :obj:`numpy.ndarray` - :math:`S_{0}` signal. - b_max : :obj:`int` - Maximum value to cap b-values. - - """ - - super().__init__(**kwargs) - - # Setup B0 map - self._S0 = None - if S0 is not None: - self._S0 = np.clip( - S0.astype("float32") / S0.max(), - a_min=DEFAULT_MIN_S0, - a_max=DEFAULT_MAX_S0, - ) - - # Cap b-values, if requested - self._gtab = gtab - self._b_max = None - if b_max and b_max > DEFAULT_MAX_BVALUE: - # Saturate b-values at b_max, since signal stops dropping - self._gtab[-1, self._gtab[-1] > b_max] = b_max - # A possibly good alternative is completely remove very high b-values - # bval_mask = gtab[-1] < b_max - # data = data[..., bval_mask] - # gtab = gtab[:, bval_mask] - self._b_max = b_max - - kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs} - - # DIPY models (or one with a fully-compliant interface) - model_str = getattr(self, "_model_class", None) - if model_str: - from importlib import import_module - - module_name, class_name = model_str.rsplit(".", 1) - self._model = getattr( - import_module(module_name), - class_name, - )(_rasb2dipy(gtab), **kwargs) - - def fit(self, data, n_jobs=None, **kwargs): - """Fit the model chunk-by-chunk asynchronously""" - n_jobs = n_jobs or 1 - - self._datashape = data.shape - - # Select voxels within mask or just unravel 3D if no mask - data = ( - data[self._mask, ...] if self._mask is not None else data.reshape(-1, data.shape[-1]) - ) - - # One single CPU - linear execution (full model) - if n_jobs == 1: - self._model, _ = _exec_fit(self._model, data) - return - - # Split data into chunks of group of slices - data_chunks = np.array_split(data, n_jobs) - - self._models = [None] * n_jobs - - # Parallelize process with joblib - with Parallel(n_jobs=n_jobs) as executor: - results = executor( - delayed(_exec_fit)(self._model, dchunk, i) for i, dchunk in enumerate(data_chunks) - ) - for submodel, index in results: - self._models[index] = submodel - - self._is_fitted = True - self._model = None # Preempt further actions on the model - - def predict(self, gradient=None, **kwargs): - """Predict asynchronously chunk-by-chunk the diffusion signal.""" - - if gradient is None: - raise ValueError("A gradient to be simulated (b-vector, b-value) must be provided") - - if not self._is_fitted: - raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") - - gradient = np.array(gradient) # Tuples are unmutable - - # Cap the b-value if b_max is defined - gradient[-1] = min(gradient[-1], self._b_max or gradient[-1]) - - gradient = _rasb2dipy(gradient) - - S0 = None - if self._S0 is not None: - S0 = ( - self._S0[self._mask, ...] - if self._mask is not None - else self._S0.reshape(-1, self._S0.shape[-1]) - ) - - n_models = len(self._models) if self._model is None and self._models else 1 - - if n_models == 1: - predicted, _ = _exec_predict(self._model, **(kwargs | {"gtab": gradient, "S0": S0})) - else: - S0 = np.array_split(S0, n_models) if S0 is not None else [None] * n_models - - predicted = [None] * n_models - - # Parallelize process with joblib - with Parallel(n_jobs=n_models) as executor: - results = executor( - delayed(_exec_predict)( - model, - chunk=i, - **(kwargs | {"gtab": gradient, "S0": S0[i]}), - ) - for i, model in enumerate(self._models) - ) - for subprediction, index in results: - predicted[index] = subprediction - - predicted = np.hstack(predicted) - - if self._mask is not None: - retval = np.zeros_like(self._mask, dtype="float32") - retval[self._mask, ...] = predicted - else: - retval = predicted.reshape(self._datashape[:-1]) - - return retval - - -class AverageDWModel(BaseDWIModel): - """A trivial model that returns an average map.""" - - __slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat", "_is_fitted") - - def __init__(self, **kwargs): - r""" - Implement object initialization. - - Parameters - ---------- - th_low : :obj:`numbers.Number` - A lower bound for the b-value corresponding to the diffusion weighted images - that will be averaged. - th_high : :obj:`numbers.Number` - An upper bound for the b-value corresponding to the diffusion weighted images - that will be averaged. - bias : :obj:`bool` - Whether the overall distribution of each diffusion weighted image will be - standardized and centered around the - :data:`src.eddymotion.model.base.DEFAULT_CLIP_PERCENTILE` percentile. - stat : :obj:`str` - Whether the summary statistic to apply is ``"mean"`` or ``"median"``. - - """ - super().__init__(**kwargs) - - self._th_low = kwargs.get("th_low", DEFAULT_LOWB_THRESHOLD) - self._th_high = kwargs.get("th_high", DEFAULT_HIGHB_THRESHOLD) - self._bias = kwargs.get("bias", True) - self._stat = kwargs.get("stat", "median") - self._data = None - - def fit(self, data, **kwargs): - """Calculate the average.""" - - if (gtab := kwargs.pop("gtab", None)) is None: - raise ValueError("A gradient table must be provided.") - - # Select the interval of b-values for which DWIs will be averaged - b_mask = ( - ((gtab[3] >= self._th_low) & (gtab[3] <= self._th_high)) - if gtab is not None - else np.ones((data.shape[-1],), dtype=bool) - ) - shells = data[..., b_mask] - - # Regress out global signal differences - if self._bias: - centers = np.median(shells, axis=(0, 1, 2)) - reference = np.percentile(centers[centers >= 1.0], DEFAULT_CLIP_PERCENTILE) - centers[centers < 1.0] = reference - drift = reference / centers - shells = shells * drift - - # Select the summary statistic - avg_func = np.median if self._stat == "median" else np.mean - # Calculate the average - self._data = avg_func(shells, axis=-1) - self._is_fitted = True - - def predict(self, *_, **kwargs): - """Return the average map.""" - - if not self._is_fitted: - raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") - - return self._data - - -class DTIModel(BaseDWIModel): - """A wrapper of :obj:`dipy.reconst.dti.TensorModel`.""" - - _modelargs = ( - "min_signal", - "return_S0_hat", - "fit_method", - "weighting", - "sigma", - "jac", - ) - _model_class = "dipy.reconst.dti.TensorModel" - - -class DKIModel(BaseDWIModel): - """A wrapper of :obj:`dipy.reconst.dki.DiffusionKurtosisModel`.""" - - _modelargs = DTIModel._modelargs - _model_class = "dipy.reconst.dki.DiffusionKurtosisModel" - - -class GPModel(BaseDWIModel): - """A wrapper of :obj:`~eddymotion.model.dipy.GaussianProcessModel`.""" - - _modelargs = ("kernel_model",) - _model_class = "eddymotion.model.dipy.GaussianProcessModel" - - -def _rasb2dipy(gradient): - gradient = np.asanyarray(gradient) - if gradient.ndim == 1: - if gradient.size != 4: - raise ValueError("Missing gradient information.") - gradient = gradient[..., np.newaxis] - - if gradient.shape[0] != 4: - gradient = gradient.T - elif gradient.shape == (4, 4): - print("Warning: make sure gradient information is not transposed!") - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning) - retval = gradient_table(gradient[3, :], gradient[:3, :].T) - return retval diff --git a/src/eddymotion/model/dmri.py b/src/eddymotion/model/dmri.py new file mode 100644 index 00000000..45e4e8cd --- /dev/null +++ b/src/eddymotion/model/dmri.py @@ -0,0 +1,378 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright 2024 The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY kIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# + +import numpy as np +from joblib import Parallel, delayed + +from eddymotion.exceptions import ModelNotFittedError +from eddymotion.model._dipy import _rasb2dipy +from eddymotion.model.base import BaseModel + + +def _exec_fit(model, data, chunk=None): + retval = model.fit(data) + return retval, chunk + + +def _exec_predict(model, chunk=None, **kwargs): + """Propagate model parameters and call predict.""" + return np.squeeze(model.predict(**kwargs)), chunk + + +DEFAULT_CLIP_PERCENTILE = 75 +"""Upper percentile threshold for intensity clipping.""" + +DEFAULT_MIN_S0 = 1e-5 +"""Minimum value when considering the :math:`S_{0}` DWI signal.""" + +DEFAULT_MAX_S0 = 1.0 +"""Maximum value when considering the :math:`S_{0}` DWI signal.""" + +DEFAULT_MAX_BVALUE = 1000 +"""Maximum allowed value for the b-value.""" + +DEFAULT_LOWB_THRESHOLD = 50 +"""The lower bound for the b-value so that the orientation is considered a DW volume.""" + +DEFAULT_HIGHB_THRESHOLD = 10000 +"""A b-value cap for DWI data.""" + +DEFAULT_NUM_BINS = 15 +"""Number of bins to classify b-values.""" + +DEFAULT_MULTISHELL_BIN_COUNT_THR = 7 +"""Default bin count to consider a multishell scheme.""" + +DEFAULT_MAX_BVAL = 8000 +"""Maximum b-value cap.""" + + +class BaseDWIModel(BaseModel): + """Interface and default methods for DWI models.""" + + __slots__ = ( + "_gtab", + "_S0", + "_b_max", + "_model_class", # Defining a model class, DIPY models are instantiated automagically + "_modelargs", + ) + + def __init__(self, gtab, S0=None, b_max=None, **kwargs): + """Initialization. + + Parameters + ---------- + gtab : :obj:`numpy.ndarray` + An :math:`N \times 4` table, where rows (*N*) are diffusion gradients and + columns are b-vector components and corresponding b-value, respectively. + S0 : :obj:`numpy.ndarray` + :math:`S_{0}` signal. + b_max : :obj:`int` + Maximum value to cap b-values. + + """ + + super().__init__(**kwargs) + + # Setup B0 map + self._S0 = None + if S0 is not None: + self._S0 = np.clip( + S0.astype("float32") / S0.max(), + a_min=DEFAULT_MIN_S0, + a_max=DEFAULT_MAX_S0, + ) + + # Cap b-values, if requested + self._gtab = gtab + self._b_max = None + if b_max and b_max > DEFAULT_MAX_BVALUE: + # Saturate b-values at b_max, since signal stops dropping + self._gtab[-1, self._gtab[-1] > b_max] = b_max + # A possibly good alternative is completely remove very high b-values + # bval_mask = gtab[-1] < b_max + # data = data[..., bval_mask] + # gtab = gtab[:, bval_mask] + self._b_max = b_max + + kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs} + + # DIPY models (or one with a fully-compliant interface) + model_str = getattr(self, "_model_class", None) + if model_str: + from importlib import import_module + + module_name, class_name = model_str.rsplit(".", 1) + self._model = getattr( + import_module(module_name), + class_name, + )(_rasb2dipy(gtab), **kwargs) + + def fit(self, data, n_jobs=None, **kwargs): + """Fit the model chunk-by-chunk asynchronously""" + n_jobs = n_jobs or 1 + + self._datashape = data.shape + + # Select voxels within mask or just unravel 3D if no mask + data = ( + data[self._mask, ...] if self._mask is not None else data.reshape(-1, data.shape[-1]) + ) + + # One single CPU - linear execution (full model) + if n_jobs == 1: + self._model, _ = _exec_fit(self._model, data) + return + + # Split data into chunks of group of slices + data_chunks = np.array_split(data, n_jobs) + + self._models = [None] * n_jobs + + # Parallelize process with joblib + with Parallel(n_jobs=n_jobs) as executor: + results = executor( + delayed(_exec_fit)(self._model, dchunk, i) for i, dchunk in enumerate(data_chunks) + ) + for submodel, index in results: + self._models[index] = submodel + + self._is_fitted = True + self._model = None # Preempt further actions on the model + + def predict(self, gradient=None, **kwargs): + """Predict asynchronously chunk-by-chunk the diffusion signal.""" + + if gradient is None: + raise ValueError("A gradient to be simulated (b-vector, b-value) must be provided") + + if not self._is_fitted: + raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") + + gradient = np.array(gradient) # Tuples are unmutable + + # Cap the b-value if b_max is defined + gradient[-1] = min(gradient[-1], self._b_max or gradient[-1]) + + gradient = _rasb2dipy(gradient) + + S0 = None + if self._S0 is not None: + S0 = ( + self._S0[self._mask, ...] + if self._mask is not None + else self._S0.reshape(-1, self._S0.shape[-1]) + ) + + n_models = len(self._models) if self._model is None and self._models else 1 + + if n_models == 1: + predicted, _ = _exec_predict(self._model, **(kwargs | {"gtab": gradient, "S0": S0})) + else: + S0 = np.array_split(S0, n_models) if S0 is not None else [None] * n_models + + predicted = [None] * n_models + + # Parallelize process with joblib + with Parallel(n_jobs=n_models) as executor: + results = executor( + delayed(_exec_predict)( + model, + chunk=i, + **(kwargs | {"gtab": gradient, "S0": S0[i]}), + ) + for i, model in enumerate(self._models) + ) + for subprediction, index in results: + predicted[index] = subprediction + + predicted = np.hstack(predicted) + + if self._mask is not None: + retval = np.zeros_like(self._mask, dtype="float32") + retval[self._mask, ...] = predicted + else: + retval = predicted.reshape(self._datashape[:-1]) + + return retval + + +class AverageDWModel(BaseDWIModel): + """A trivial model that returns an average map.""" + + __slots__ = ("_data", "_th_low", "_th_high", "_bias", "_stat", "_is_fitted") + + def __init__(self, **kwargs): + r""" + Implement object initialization. + + Parameters + ---------- + th_low : :obj:`numbers.Number` + A lower bound for the b-value corresponding to the diffusion weighted images + that will be averaged. + th_high : :obj:`numbers.Number` + An upper bound for the b-value corresponding to the diffusion weighted images + that will be averaged. + bias : :obj:`bool` + Whether the overall distribution of each diffusion weighted image will be + standardized and centered around the + :data:`src.eddymotion.model.base.DEFAULT_CLIP_PERCENTILE` percentile. + stat : :obj:`str` + Whether the summary statistic to apply is ``"mean"`` or ``"median"``. + + """ + super().__init__(**kwargs) + + self._th_low = kwargs.get("th_low", DEFAULT_LOWB_THRESHOLD) + self._th_high = kwargs.get("th_high", DEFAULT_HIGHB_THRESHOLD) + self._bias = kwargs.get("bias", True) + self._stat = kwargs.get("stat", "median") + self._data = None + + def fit(self, data, **kwargs): + """Calculate the average.""" + + if (gtab := kwargs.pop("gtab", None)) is None: + raise ValueError("A gradient table must be provided.") + + # Select the interval of b-values for which DWIs will be averaged + b_mask = ( + ((gtab[3] >= self._th_low) & (gtab[3] <= self._th_high)) + if gtab is not None + else np.ones((data.shape[-1],), dtype=bool) + ) + shells = data[..., b_mask] + + # Regress out global signal differences + if self._bias: + centers = np.median(shells, axis=(0, 1, 2)) + reference = np.percentile(centers[centers >= 1.0], DEFAULT_CLIP_PERCENTILE) + centers[centers < 1.0] = reference + drift = reference / centers + shells = shells * drift + + # Select the summary statistic + avg_func = np.median if self._stat == "median" else np.mean + # Calculate the average + self._data = avg_func(shells, axis=-1) + self._is_fitted = True + + def predict(self, *_, **kwargs): + """Return the average map.""" + + if not self._is_fitted: + raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") + + return self._data + + +class DTIModel(BaseDWIModel): + """A wrapper of :obj:`dipy.reconst.dti.TensorModel`.""" + + _modelargs = ( + "min_signal", + "return_S0_hat", + "fit_method", + "weighting", + "sigma", + "jac", + ) + _model_class = "dipy.reconst.dti.TensorModel" + + +class DKIModel(BaseDWIModel): + """A wrapper of :obj:`dipy.reconst.dki.DiffusionKurtosisModel`.""" + + _modelargs = DTIModel._modelargs + _model_class = "dipy.reconst.dki.DiffusionKurtosisModel" + + +class GPModel(BaseDWIModel): + """A wrapper of :obj:`~eddymotion.model.dipy.GaussianProcessModel`.""" + + _modelargs = ("kernel_model",) + _model_class = "eddymotion.model._dipy.GaussianProcessModel" + + +def find_shelling_scheme( + bvals, + num_bins=DEFAULT_NUM_BINS, + multishell_nonempty_bin_count_thr=DEFAULT_MULTISHELL_BIN_COUNT_THR, + bval_cap=DEFAULT_MAX_BVAL, +): + """ + Find the shelling scheme on the given b-values. + + Computes the histogram of the b-values according to ``num_bins`` + and depending on the nonempty bin count, classify the shelling scheme + as single-shell if they are 2 (low-b and a shell); multi-shell if they are + below the ``multishell_nonempty_bin_count_thr`` value; and DSI otherwise. + + Parameters + ---------- + bvals : :obj:`list` or :obj:`~numpy.ndarray` + List or array of b-values. + num_bins : :obj:`int`, optional + Number of bins. + multishell_nonempty_bin_count_thr : :obj:`int`, optional + Bin count to consider a multi-shell scheme. + + Returns + ------- + scheme : :obj:`str` + Shelling scheme. + bval_groups : :obj:`list` + List of grouped b-values. + bval_estimated : :obj:`list` + List of 'estimated' b-values as the median value of each b-value group. + + """ + + # Bin the b-values: use -1 as the lower bound to be able to appropriately + # include b0 values + hist, bin_edges = np.histogram(bvals, bins=num_bins, range=(-1, min(max(bvals), bval_cap))) + + # Collect values in each bin + bval_groups = [] + bval_estimated = [] + for lower, upper in zip(bin_edges[:-1], bin_edges[1:], strict=False): + # Add only if a nonempty b-values mask + if (mask := (bvals > lower) & (bvals <= upper)).sum(): + bval_groups.append(bvals[mask]) + bval_estimated.append(np.median(bvals[mask])) + + nonempty_bins = len(bval_groups) + + if nonempty_bins < 2: + raise ValueError("DWI must have at least one high-b shell") + + if nonempty_bins == 2: + scheme = "single-shell" + elif nonempty_bins < multishell_nonempty_bin_count_thr: + scheme = "multi-shell" + else: + scheme = "DSI" + + return scheme, bval_groups, bval_estimated diff --git a/src/eddymotion/model/dmri_utils.py b/src/eddymotion/model/dmri_utils.py deleted file mode 100644 index 079bbcc5..00000000 --- a/src/eddymotion/model/dmri_utils.py +++ /dev/null @@ -1,94 +0,0 @@ -# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- -# vi: set ft=python sts=4 ts=4 sw=4 et: -# -# Copyright 2024 The NiPreps Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY kIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# We support and encourage derived works from this project, please read -# about our expectations at -# -# https://www.nipreps.org/community/licensing/ -# -import numpy as np - -DEFAULT_NUM_BINS = 15 -"""Number of bins to classify b-values.""" - -DEFAULT_MULTISHELL_BIN_COUNT_THR = 7 -"""Default bin count to consider a multishell scheme.""" - -DEFAULT_MAX_BVAL = 8000 -"""Maximum b-value cap.""" - - -def find_shelling_scheme( - bvals, - num_bins=DEFAULT_NUM_BINS, - multishell_nonempty_bin_count_thr=DEFAULT_MULTISHELL_BIN_COUNT_THR, - bval_cap=DEFAULT_MAX_BVAL, -): - """ - Find the shelling scheme on the given b-values. - - Computes the histogram of the b-values according to ``num_bins`` - and depending on the nonempty bin count, classify the shelling scheme - as single-shell if they are 2 (low-b and a shell); multi-shell if they are - below the ``multishell_nonempty_bin_count_thr`` value; and DSI otherwise. - - Parameters - ---------- - bvals : :obj:`list` or :obj:`~numpy.ndarray` - List or array of b-values. - num_bins : :obj:`int`, optional - Number of bins. - multishell_nonempty_bin_count_thr : :obj:`int`, optional - Bin count to consider a multi-shell scheme. - - Returns - ------- - scheme : :obj:`str` - Shelling scheme. - bval_groups : :obj:`list` - List of grouped b-values. - bval_estimated : :obj:`list` - List of 'estimated' b-values as the median value of each b-value group. - - """ - - # Bin the b-values: use -1 as the lower bound to be able to appropriately - # include b0 values - hist, bin_edges = np.histogram(bvals, bins=num_bins, range=(-1, min(max(bvals), bval_cap))) - - # Collect values in each bin - bval_groups = [] - bval_estimated = [] - for lower, upper in zip(bin_edges[:-1], bin_edges[1:], strict=False): - # Add only if a nonempty b-values mask - if (mask := (bvals > lower) & (bvals <= upper)).sum(): - bval_groups.append(bvals[mask]) - bval_estimated.append(np.median(bvals[mask])) - - nonempty_bins = len(bval_groups) - - if nonempty_bins < 2: - raise ValueError("DWI must have at least one high-b shell") - - if nonempty_bins == 2: - scheme = "single-shell" - elif nonempty_bins < multishell_nonempty_bin_count_thr: - scheme = "multi-shell" - else: - scheme = "DSI" - - return scheme, bval_groups, bval_estimated diff --git a/src/eddymotion/model/pet.py b/src/eddymotion/model/pet.py new file mode 100644 index 00000000..f6e1299c --- /dev/null +++ b/src/eddymotion/model/pet.py @@ -0,0 +1,138 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright 2022 The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# +"""Models for nuclear imaging.""" + +import numpy as np +from joblib import Parallel, delayed + +from eddymotion.exceptions import ModelNotFittedError +from eddymotion.model.base import BaseModel + +DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2 +"""Time frame tolerance in seconds.""" + + +class PETModel(BaseModel): + """A PET imaging realignment model based on B-Spline approximation.""" + + __slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_n_ctrl") + + def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs): + """ + Create the B-Spline interpolating matrix. + + Parameters: + ----------- + timepoints : :obj:`list` + The timing (in sec) of each PET volume. + E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330., + 420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]`` + + n_ctrl : :obj:`int` + Number of B-Spline control points. If `None`, then one control point every + six timepoints will be used. The less control points, the smoother is the + model. + + """ + super.__init__(**kwargs) + + if timepoints is None or xlim is None: + raise TypeError("timepoints must be provided in initialization") + + self._order = order + + self._x = np.array(timepoints, dtype="float32") + self._xlim = xlim + + if self._x[0] < DEFAULT_TIMEFRAME_MIDPOINT_TOL: + raise ValueError("First frame midpoint should not be zero or negative") + if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL): + raise ValueError("Last frame midpoint should not be equal or greater than duration") + + # Calculate index coordinates in the B-Spline grid + self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1 + + # B-Spline knots + self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32") + + self._coeff = None + + @property + def is_fitted(self): + return self._coeff is not None + + def fit(self, data, **kwargs): + """Fit the model.""" + from scipy.interpolate import BSpline + from scipy.sparse.linalg import cg + + n_jobs = kwargs.pop("n_jobs", None) or 1 + + timepoints = kwargs.get("timepoints", None) or self._x + x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl + + self._datashape = data.shape[:3] + + # Convert data into V (voxels) x T (timepoints) + data = data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask] + + # A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding) + A = BSpline.design_matrix(x, self._t, k=self._order) + AT = A.T + ATdotA = AT @ A + + # One single CPU - linear execution (full model) + if n_jobs == 1: + self._coeff = np.array([cg(ATdotA, AT @ v)[0] for v in data]) + return + + # Parallelize process with joblib + with Parallel(n_jobs=n_jobs) as executor: + results = executor(delayed(cg)(ATdotA, AT @ v) for v in data) + + self._coeff = np.array([r[0] for r in results]) + + def predict(self, index=None, **kwargs): + """Return the corrected volume using B-spline interpolation.""" + from scipy.interpolate import BSpline + + if index is None: + raise ValueError("A timepoint index to be simulated must be provided.") + + if not self._is_fitted: + raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") + + # Project sample timing into B-Spline coordinates + x = (index / self._xlim) * self._n_ctrl + A = BSpline.design_matrix(x, self._t, k=self._order) + + # A is 1 (num. timepoints) x C (num. coeff) + # self._coeff is V (num. voxels) x K - 4 + predicted = np.squeeze(A @ self._coeff.T) + + if self._mask is None: + return predicted.reshape(self._datashape) + + retval = np.zeros(self._datashape, dtype="float32") + retval[self._mask] = predicted + return retval From 79cac32035b1169001813d6228f1a37eaac92cbb Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 29 Aug 2024 17:05:20 +0200 Subject: [PATCH 3/5] fix: revise tests --- test/test_dmri.py | 613 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 613 insertions(+) diff --git a/test/test_dmri.py b/test/test_dmri.py index dde0a44d..e0ff158d 100644 --- a/test/test_dmri.py +++ b/test/test_dmri.py @@ -27,6 +27,9 @@ import pytest from eddymotion.data.dmri import load +from eddymotion.model.dmri import ( + find_shelling_scheme, +) def _create_dwi_random_dataobj(): @@ -193,3 +196,613 @@ def test_equality_operator(tmp_path): # Symmetric equality assert dwi_obj == round_trip_dwi_obj assert round_trip_dwi_obj == dwi_obj + + +@pytest.mark.parametrize( + ("bvals", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"), + [ + ( + np.asarray( + [ + 5, + 300, + 300, + 300, + 300, + 300, + 305, + 1005, + 995, + 1000, + 1000, + 1005, + 1000, + 1000, + 1005, + 995, + 1000, + 1005, + 5, + 995, + 1000, + 1000, + 995, + 1005, + 995, + 1000, + 995, + 995, + 2005, + 2000, + 2005, + 2005, + 1995, + 2000, + 2005, + 2000, + 1995, + 2005, + 5, + 1995, + 2005, + 1995, + 1995, + 2005, + 2005, + 1995, + 2000, + 2000, + 2000, + 1995, + 2000, + 2000, + 2005, + 2005, + 1995, + 2005, + 2005, + 1990, + 1995, + 1995, + 1995, + 2005, + 2000, + 1990, + 2010, + 5, + ] + ), + "multi-shell", + [ + np.asarray([5, 5, 5, 5]), + np.asarray([300, 300, 300, 300, 300, 305]), + np.asarray( + [ + 1005, + 995, + 1000, + 1000, + 1005, + 1000, + 1000, + 1005, + 995, + 1000, + 1005, + 995, + 1000, + 1000, + 995, + 1005, + 995, + 1000, + 995, + 995, + ] + ), + np.asarray( + [ + 2005, + 2000, + 2005, + 2005, + 1995, + 2000, + 2005, + 2000, + 1995, + 2005, + 1995, + 2005, + 1995, + 1995, + 2005, + 2005, + 1995, + 2000, + 2000, + 2000, + 1995, + 2000, + 2000, + 2005, + 2005, + 1995, + 2005, + 2005, + 1990, + 1995, + 1995, + 1995, + 2005, + 2000, + 1990, + 2010, + ] + ), + ], + [5, 300, 1000, 2000], + ), + ], +) +def test_find_shelling_scheme_array(bvals, exp_scheme, exp_bval_groups, exp_bval_estimated): + obt_scheme, obt_bval_groups, obt_bval_estimated = find_shelling_scheme(bvals) + assert obt_scheme == exp_scheme + assert all( + np.allclose(obt_arr, exp_arr) + for obt_arr, exp_arr in zip(obt_bval_groups, exp_bval_groups, strict=True) + ) + assert np.allclose(obt_bval_estimated, exp_bval_estimated) + + +@pytest.mark.parametrize( + ("dwi_btable", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"), + [ + ( + "ds000114_singleshell", + "single-shell", + [ + np.asarray([0, 0, 0, 0, 0, 0, 0]), + np.asarray( + [ + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + ] + ), + ], + [0.0, 1000.0], + ), + ( + "hcph_multishell", + "multi-shell", + [ + np.asarray([0, 0, 0, 0, 0, 0]), + np.asarray([700, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700]), + np.asarray( + [ + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + 1000, + ] + ), + np.asarray( + [ + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + 2000, + ] + ), + np.asarray( + [ + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + 3000, + ] + ), + ], + [0.0, 700.0, 1000.0, 2000.0, 3000.0], + ), + ( + "ds004737_dsi", + "DSI", + [ + np.asarray([5, 5, 5, 5, 5, 5, 5, 5, 5]), + np.asarray([995, 995, 800, 800, 995, 995, 795, 995]), + np.asarray([1195, 1195, 1195, 1195, 1000, 1195, 1195, 1000]), + np.asarray([1595, 1595, 1595, 1600.0]), + np.asarray( + [ + 1800, + 1795, + 1795, + 1790, + 1995, + 1800, + 1795, + 1990, + 1990, + 1795, + 1990, + 1795, + 1795, + 1995, + ] + ), + np.asarray([2190, 2195, 2190, 2195, 2000, 2000, 2000, 2195, 2195, 2190]), + np.asarray([2590, 2595, 2600, 2395, 2595, 2600, 2395]), + np.array([2795, 2790, 2795, 2795, 2790, 2795, 2795, 2790, 2795]), + np.array([3590, 3395, 3595, 3595, 3395, 3395, 3400]), + np.array([3790, 3790]), + np.array([4195, 4195]), + np.array([4390, 4395, 4390]), + np.array( + [ + 4790, + 4990, + 4990, + 5000, + 5000, + 4990, + 4795, + 4985, + 5000, + 4795, + 5000, + 4990, + 4990, + 4790, + 5000, + 4990, + 4795, + 4795, + 4990, + 5000, + 4990, + ] + ), + ], + [ + 5.0, + 995.0, + 1195.0, + 1595.0, + 1797.5, + 2190.0, + 2595.0, + 2795.0, + 3400.0, + 3790.0, + 4195.0, + 4390.0, + 4990.0, + ], + ), + ], +) +def test_find_shelling_scheme_files( + dwi_btable, exp_scheme, exp_bval_groups, exp_bval_estimated, repodata +): + bvals = np.loadtxt(repodata / f"{dwi_btable}.bval") + + obt_scheme, obt_bval_groups, obt_bval_estimated = find_shelling_scheme(bvals) + assert obt_scheme == exp_scheme + assert all( + np.allclose(obt_arr, exp_arr) + for obt_arr, exp_arr in zip(obt_bval_groups, exp_bval_groups, strict=True) + ) + assert np.allclose(obt_bval_estimated, exp_bval_estimated) From 1d22912e9ddeb7cc7e508d18086c597540b76b7e Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 29 Aug 2024 17:38:32 +0200 Subject: [PATCH 4/5] fix: revise tests, all green locally --- test/test_dipy.py | 2 +- test/test_dmri_utils.py | 639 ---------------------------------------- test/test_model.py | 16 +- 3 files changed, 9 insertions(+), 648 deletions(-) delete mode 100644 test/test_dmri_utils.py diff --git a/test/test_dipy.py b/test/test_dipy.py index cd62415a..54c91006 100644 --- a/test/test_dipy.py +++ b/test/test_dipy.py @@ -27,7 +27,7 @@ from dipy.core.gradients import gradient_table from dipy.io import read_bvals_bvecs -from eddymotion.model.dipy import ( +from eddymotion.model._dipy import ( PairwiseOrientationKernel, compute_exponential_covariance, compute_pairwise_angles, diff --git a/test/test_dmri_utils.py b/test/test_dmri_utils.py deleted file mode 100644 index f5738c22..00000000 --- a/test/test_dmri_utils.py +++ /dev/null @@ -1,639 +0,0 @@ -# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- -# vi: set ft=python sts=4 ts=4 sw=4 et: -# -# Copyright 2024 The NiPreps Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# We support and encourage derived works from this project, please read -# about our expectations at -# -# https://www.nipreps.org/community/licensing/ -# - -import numpy as np -import pytest - -from eddymotion.model.dmri_utils import ( - find_shelling_scheme, -) - - -@pytest.mark.parametrize( - ("bvals", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"), - [ - ( - np.asarray( - [ - 5, - 300, - 300, - 300, - 300, - 300, - 305, - 1005, - 995, - 1000, - 1000, - 1005, - 1000, - 1000, - 1005, - 995, - 1000, - 1005, - 5, - 995, - 1000, - 1000, - 995, - 1005, - 995, - 1000, - 995, - 995, - 2005, - 2000, - 2005, - 2005, - 1995, - 2000, - 2005, - 2000, - 1995, - 2005, - 5, - 1995, - 2005, - 1995, - 1995, - 2005, - 2005, - 1995, - 2000, - 2000, - 2000, - 1995, - 2000, - 2000, - 2005, - 2005, - 1995, - 2005, - 2005, - 1990, - 1995, - 1995, - 1995, - 2005, - 2000, - 1990, - 2010, - 5, - ] - ), - "multi-shell", - [ - np.asarray([5, 5, 5, 5]), - np.asarray([300, 300, 300, 300, 300, 305]), - np.asarray( - [ - 1005, - 995, - 1000, - 1000, - 1005, - 1000, - 1000, - 1005, - 995, - 1000, - 1005, - 995, - 1000, - 1000, - 995, - 1005, - 995, - 1000, - 995, - 995, - ] - ), - np.asarray( - [ - 2005, - 2000, - 2005, - 2005, - 1995, - 2000, - 2005, - 2000, - 1995, - 2005, - 1995, - 2005, - 1995, - 1995, - 2005, - 2005, - 1995, - 2000, - 2000, - 2000, - 1995, - 2000, - 2000, - 2005, - 2005, - 1995, - 2005, - 2005, - 1990, - 1995, - 1995, - 1995, - 2005, - 2000, - 1990, - 2010, - ] - ), - ], - [5, 300, 1000, 2000], - ), - ], -) -def test_find_shelling_scheme_array(bvals, exp_scheme, exp_bval_groups, exp_bval_estimated): - obt_scheme, obt_bval_groups, obt_bval_estimated = find_shelling_scheme(bvals) - assert obt_scheme == exp_scheme - assert all( - np.allclose(obt_arr, exp_arr) - for obt_arr, exp_arr in zip(obt_bval_groups, exp_bval_groups, strict=True) - ) - assert np.allclose(obt_bval_estimated, exp_bval_estimated) - - -@pytest.mark.parametrize( - ("dwi_btable", "exp_scheme", "exp_bval_groups", "exp_bval_estimated"), - [ - ( - "ds000114_singleshell", - "single-shell", - [ - np.asarray([0, 0, 0, 0, 0, 0, 0]), - np.asarray( - [ - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - ] - ), - ], - [0.0, 1000.0], - ), - ( - "hcph_multishell", - "multi-shell", - [ - np.asarray([0, 0, 0, 0, 0, 0]), - np.asarray([700, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700, 700]), - np.asarray( - [ - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - 1000, - ] - ), - np.asarray( - [ - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - 2000, - ] - ), - np.asarray( - [ - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - 3000, - ] - ), - ], - [0.0, 700.0, 1000.0, 2000.0, 3000.0], - ), - ( - "ds004737_dsi", - "DSI", - [ - np.asarray([5, 5, 5, 5, 5, 5, 5, 5, 5]), - np.asarray([995, 995, 800, 800, 995, 995, 795, 995]), - np.asarray([1195, 1195, 1195, 1195, 1000, 1195, 1195, 1000]), - np.asarray([1595, 1595, 1595, 1600.0]), - np.asarray( - [ - 1800, - 1795, - 1795, - 1790, - 1995, - 1800, - 1795, - 1990, - 1990, - 1795, - 1990, - 1795, - 1795, - 1995, - ] - ), - np.asarray([2190, 2195, 2190, 2195, 2000, 2000, 2000, 2195, 2195, 2190]), - np.asarray([2590, 2595, 2600, 2395, 2595, 2600, 2395]), - np.array([2795, 2790, 2795, 2795, 2790, 2795, 2795, 2790, 2795]), - np.array([3590, 3395, 3595, 3595, 3395, 3395, 3400]), - np.array([3790, 3790]), - np.array([4195, 4195]), - np.array([4390, 4395, 4390]), - np.array( - [ - 4790, - 4990, - 4990, - 5000, - 5000, - 4990, - 4795, - 4985, - 5000, - 4795, - 5000, - 4990, - 4990, - 4790, - 5000, - 4990, - 4795, - 4795, - 4990, - 5000, - 4990, - ] - ), - ], - [ - 5.0, - 995.0, - 1195.0, - 1595.0, - 1797.5, - 2190.0, - 2595.0, - 2795.0, - 3400.0, - 3790.0, - 4195.0, - 4390.0, - 4990.0, - ], - ), - ], -) -def test_find_shelling_scheme_files( - dwi_btable, exp_scheme, exp_bval_groups, exp_bval_estimated, repodata -): - bvals = np.loadtxt(repodata / f"{dwi_btable}.bval") - - obt_scheme, obt_bval_groups, obt_bval_estimated = find_shelling_scheme(bvals) - assert obt_scheme == exp_scheme - assert all( - np.allclose(obt_arr, exp_arr) - for obt_arr, exp_arr in zip(obt_bval_groups, exp_bval_groups, strict=True) - ) - assert np.allclose(obt_bval_estimated, exp_bval_estimated) diff --git a/test/test_model.py b/test/test_model.py index ba60ebfe..2173aa3d 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -31,8 +31,8 @@ from eddymotion.data.dmri import DWI from eddymotion.data.splitting import lovo_split from eddymotion.exceptions import ModelNotFittedError -from eddymotion.model.base import DEFAULT_MAX_S0, DEFAULT_MIN_S0 -from eddymotion.model.dipy import GaussianProcessModel +from eddymotion.model._dipy import GaussianProcessModel +from eddymotion.model.dmri import DEFAULT_MAX_S0, DEFAULT_MIN_S0 def test_trivial_model(): @@ -40,9 +40,9 @@ def test_trivial_model(): rng = np.random.default_rng(1234) - # Should not allow initialization without a B0 - with pytest.raises(ValueError): - model.TrivialB0Model(gtab=np.eye(4)) + # Should not allow initialization without an oracle + with pytest.raises(TypeError): + model.TrivialModel() _S0 = rng.normal(size=(2, 2, 2)) @@ -52,7 +52,7 @@ def test_trivial_model(): a_max=DEFAULT_MAX_S0, ) - tmodel = model.TrivialB0Model(gtab=np.eye(4), S0=_clipped_S0) + tmodel = model.TrivialModel(predicted=_clipped_S0) data = None assert tmodel.fit(data) is None @@ -111,7 +111,7 @@ def test_average_model(): def test_gp_model(): gp = GaussianProcessModel("test") - assert isinstance(gp, model.dipy.GaussianProcessModel) + assert isinstance(gp, model._dipy.GaussianProcessModel) X, y = make_regression(n_samples=100, n_features=3, noise=0, random_state=0) @@ -150,7 +150,7 @@ def test_two_initialisations(datadir): # Initialisation via ModelFactory model2 = model.ModelFactory.init( gtab=data_train[1], - model="avg", + model="avgdwi", S0=dmri_dataset.bzero, th_low=100, th_high=1000, From f2576d3673efab755a95df489e48a7f01d09d96c Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Thu, 29 Aug 2024 18:11:10 +0200 Subject: [PATCH 5/5] doc: fix warning scalated to error (unlinked submodule) --- docs/developers.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/developers.rst b/docs/developers.rst index edb6a401..e16c028f 100644 --- a/docs/developers.rst +++ b/docs/developers.rst @@ -35,4 +35,5 @@ Information on specific functions, classes, and methods. api/eddymotion.exceptions api/eddymotion.math api/eddymotion.model + api/eddymotion.registration api/eddymotion.utils