From abae1cc75f553810aeaad4d4f029207010ed3099 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 12 Jun 2024 12:22:40 -0400 Subject: [PATCH] Code review of #176 (#6) * enh: revise code * sty: ruff format --- src/eddymotion/model/base.py | 317 +++++++++++++++++++---------------- test/test_model.py | 4 +- 2 files changed, 176 insertions(+), 145 deletions(-) diff --git a/src/eddymotion/model/base.py b/src/eddymotion/model/base.py index 7a85613b..4ff28efb 100644 --- a/src/eddymotion/model/base.py +++ b/src/eddymotion/model/base.py @@ -27,6 +27,7 @@ import numpy as np from dipy.core.gradients import gradient_table from joblib import Parallel, delayed + from eddymotion.exceptions import ModelNotFittedError DEFAULT_MIN_S0 = 1e-5 @@ -56,9 +57,9 @@ def _exec_fit(model, data, chunk=None): return retval, chunk -def _exec_predict_dwi(model, gradient, chunk=None, **kwargs): +def _exec_predict(model, chunk=None, **kwargs): """Propagate model parameters and call predict.""" - return np.squeeze(model.predict(gradient, S0=kwargs.pop("S0", None))), chunk + return np.squeeze(model.predict(**kwargs)), chunk class ModelFactory: @@ -82,7 +83,7 @@ def init(model="DTI", **kwargs): """ if model.lower() in ("s0", "b0"): - return TrivialB0Model(S0=kwargs.pop("S0")) + return TrivialB0Model(S0=kwargs.pop("S0"), gtab=kwargs.pop("gtab")) if model.lower() in ("avg", "average", "mean"): return AverageDWModel(**kwargs) @@ -117,54 +118,134 @@ class BaseModel: def __init__(self, mask=None, **kwargs): """Base initialization.""" - self._model = None + # Keep model state + self._model = None # "Main" model + self._models = None # For parallel (chunked) execution self._is_fitted = False # Setup brain mask self._mask = mask self._datashape = None - self._models = None self._is_fitted = False @property def is_fitted(self): return self._is_fitted - def fit(self, data, n_jobs=None, **kwargs): - """Fit the model chunk-by-chunk asynchronously""" - n_jobs = n_jobs or 1 + def fit(self, data, **kwargs): + """Abstract member signature of fit().""" + raise NotImplementedError("Cannot call fit() on a BaseModel instance.") - self._datashape = data.shape + def predict(self, *args, **kwargs): + """Abstract member signature of predict().""" + raise NotImplementedError("Cannot call predict() on a BaseModel instance.") - # Select voxels within mask or just unravel 3D if no mask - data = ( - data[self._mask, ...] if self._mask is not None else data.reshape(-1, data.shape[-1]) - ) + +class PETModel(BaseModel): + """A PET imaging realignment model based on B-Spline approximation.""" + + __slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_n_ctrl") + + def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs): + """ + Create the B-Spline interpolating matrix. + + Parameters: + ----------- + timepoints : :obj:`list` + The timing (in sec) of each PET volume. + E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330., + 420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]`` + + n_ctrl : :obj:`int` + Number of B-Spline control points. If `None`, then one control point every + six timepoints will be used. The less control points, the smoother is the + model. + + """ + super.__init__(**kwargs) + + if timepoints is None or xlim is None: + raise TypeError("timepoints must be provided in initialization") + + self._order = order + + self._x = np.array(timepoints, dtype="float32") + self._xlim = xlim + + if self._x[0] < DEFAULT_TIMEFRAME_MIDPOINT_TOL: + raise ValueError("First frame midpoint should not be zero or negative") + if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL): + raise ValueError("Last frame midpoint should not be equal or greater than duration") + + # Calculate index coordinates in the B-Spline grid + self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1 + + # B-Spline knots + self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32") + + self._coeff = None + + @property + def is_fitted(self): + return self._coeff is not None + + def fit(self, data, **kwargs): + """Fit the model.""" + from scipy.interpolate import BSpline + from scipy.sparse.linalg import cg + + n_jobs = kwargs.pop("n_jobs", None) or 1 + + timepoints = kwargs.get("timepoints", None) or self._x + x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl + + self._datashape = data.shape[:3] + + # Convert data into V (voxels) x T (timepoints) + data = data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask] + + # A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding) + A = BSpline.design_matrix(x, self._t, k=self._order) + AT = A.T + ATdotA = AT @ A # One single CPU - linear execution (full model) if n_jobs == 1: - self._model, _ = _exec_fit(self._model, data) + self._coeff = np.array([cg(ATdotA, AT @ v)[0] for v in data]) return - # Split data into chunks of group of slices - data_chunks = np.array_split(data, n_jobs) - - self._models = [None] * n_jobs - # Parallelize process with joblib with Parallel(n_jobs=n_jobs) as executor: - results = executor( - delayed(_exec_fit)(self._model, dchunk, i) for i, dchunk in enumerate(data_chunks) - ) - for submodel, index in results: - self._models[index] = submodel + results = executor(delayed(cg)(ATdotA, AT @ v) for v in data) - self._is_fitted = True - self._model = None # Preempt further actions on the model + self._coeff = np.array([r[0] for r in results]) - def predict(self, *args, **kwargs): - pass + def predict(self, index=None, **kwargs): + """Return the corrected volume using B-spline interpolation.""" + from scipy.interpolate import BSpline + + if index is None: + raise ValueError("A timepoint index to be simulated must be provided.") + + if not self._is_fitted: + raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") + + # Project sample timing into B-Spline coordinates + x = (index / self._xlim) * self._n_ctrl + A = BSpline.design_matrix(x, self._t, k=self._order) + + # A is 1 (num. timepoints) x C (num. coeff) + # self._coeff is V (num. voxels) x K - 4 + predicted = np.squeeze(A @ self._coeff.T) + + if self._mask is None: + return predicted.reshape(self._datashape) + + retval = np.zeros(self._datashape, dtype="float32") + retval[self._mask] = predicted + return retval class BaseDWIModel(BaseModel): @@ -174,6 +255,8 @@ class BaseDWIModel(BaseModel): "_gtab", "_S0", "_b_max", + "_model_class", # Defining a model class, DIPY models are instantiated automagically + "_modelargs", ) def __init__(self, gtab, S0=None, b_max=None, **kwargs): @@ -188,6 +271,7 @@ def __init__(self, gtab, S0=None, b_max=None, **kwargs): :math:`S_{0}` signal. b_max : :obj:`int` Maximum value to cap b-values. + """ super().__init__(**kwargs) @@ -215,25 +299,64 @@ def __init__(self, gtab, S0=None, b_max=None, **kwargs): kwargs = {k: v for k, v in kwargs.items() if k in self._modelargs} + # DIPY models (or one with a fully-compliant interface) model_str = getattr(self, "_model_class", None) - if not model_str: - raise TypeError("No model defined") + if model_str: + from importlib import import_module + + module_name, class_name = model_str.rsplit(".", 1) + self._model = getattr( + import_module(module_name), + class_name, + )(_rasb2dipy(gtab), **kwargs) + + def fit(self, data, n_jobs=None, **kwargs): + """Fit the model chunk-by-chunk asynchronously""" + n_jobs = n_jobs or 1 + + self._datashape = data.shape + + # Select voxels within mask or just unravel 3D if no mask + data = ( + data[self._mask, ...] if self._mask is not None else data.reshape(-1, data.shape[-1]) + ) + + # One single CPU - linear execution (full model) + if n_jobs == 1: + self._model, _ = _exec_fit(self._model, data) + return - from importlib import import_module + # Split data into chunks of group of slices + data_chunks = np.array_split(data, n_jobs) + + self._models = [None] * n_jobs + + # Parallelize process with joblib + with Parallel(n_jobs=n_jobs) as executor: + results = executor( + delayed(_exec_fit)(self._model, dchunk, i) for i, dchunk in enumerate(data_chunks) + ) + for submodel, index in results: + self._models[index] = submodel - module_name, class_name = model_str.rsplit(".", 1) - self._model = getattr(import_module(module_name), class_name)(_rasb2dipy(gtab), **kwargs) + self._is_fitted = True + self._model = None # Preempt further actions on the model - def predict(self, index, **kwargs): + def predict(self, gradient=None, **kwargs): """Predict asynchronously chunk-by-chunk the diffusion signal.""" + if gradient is None: + raise ValueError("A gradient to be simulated (b-vector, b-value) must be provided") + if not self._is_fitted: raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") - if self._b_max is not None: - index[-1] = min(index[-1], self._b_max) + gradient = np.array(gradient) # Tuples are unmutable + + # Cap the b-value if b_max is defined + gradient[-1] = min(gradient[-1], self._b_max or gradient[-1]) - self._gtab = _rasb2dipy(self._gtab) + self._gtab = _rasb2dipy(gradient) S0 = None if self._S0 is not None: @@ -246,7 +369,7 @@ def predict(self, index, **kwargs): n_models = len(self._models) if self._model is None and self._models else 1 if n_models == 1: - predicted, _ = _exec_predict_dwi(self._model, self._gtab, S0=S0, **kwargs) + predicted, _ = _exec_predict(self._model, **(kwargs | {"gtab": self._gtab, "S0": S0})) else: S0 = np.array_split(S0, n_models) if S0 is not None else [None] * n_models @@ -255,7 +378,11 @@ def predict(self, index, **kwargs): # Parallelize process with joblib with Parallel(n_jobs=n_models) as executor: results = executor( - delayed(_exec_predict_dwi)(model, self._gtab, S0=S0[i], chunk=i, **kwargs) + delayed(_exec_predict)( + model, + chunk=i, + **(kwargs | {"gtab": self._gtab, "S0": S0[i]}), + ) for i, model in enumerate(self._models) ) for subprediction, index in results: @@ -332,10 +459,14 @@ def __init__(self, **kwargs): def fit(self, data, **kwargs): """Calculate the average.""" + + if (gtab := kwargs.pop("gtab", None)) is None: + raise ValueError("A gradient table must be provided.") + # Select the interval of b-values for which DWIs will be averaged b_mask = ( - ((self._gtab[3] >= self._th_low) & (self._gtab[3] <= self._th_high)) - if self._gtab is not None + ((gtab[3] >= self._th_low) & (gtab[3] <= self._th_high)) + if gtab is not None else np.ones((data.shape[-1],), dtype=bool) ) shells = data[..., b_mask] @@ -358,7 +489,7 @@ def fit(self, data, **kwargs): def is_fitted(self): return self._is_fitted - def predict(self, gradient, **kwargs): + def predict(self, *_, **kwargs): """Return the average map.""" if not self._is_fitted: @@ -367,108 +498,6 @@ def predict(self, gradient, **kwargs): return self._data -class PETModel(BaseModel): - """A PET imaging realignment model based on B-Spline approximation.""" - - __slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_n_ctrl") - - def __init__(self, timepoints=None, xlim=None, n_ctrl=None, order=3, **kwargs): - """ - Create the B-Spline interpolating matrix. - - Parameters: - ----------- - timepoints : :obj:`list` - The timing (in sec) of each PET volume. - E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330., - 420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]`` - - n_ctrl : :obj:`int` - Number of B-Spline control points. If `None`, then one control point every - six timepoints will be used. The less control points, the smoother is the - model. - - """ - super.__init__(**kwargs) - - if timepoints is None or xlim is None: - raise TypeError("timepoints must be provided in initialization") - - self._order = order - - self._x = np.array(timepoints, dtype="float32") - self._xlim = xlim - - if self._x[0] < DEFAULT_TIMEFRAME_MIDPOINT_TOL: - raise ValueError("First frame midpoint should not be zero or negative") - if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL): - raise ValueError("Last frame midpoint should not be equal or greater than duration") - - # Calculate index coordinates in the B-Spline grid - self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1 - - # B-Spline knots - self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32") - - self._coeff = None - - @property - def is_fitted(self): - return self._coeff is not None - def fit(self, data, **kwargs): - """Fit the model.""" - from scipy.interpolate import BSpline - from scipy.sparse.linalg import cg - - n_jobs = kwargs.pop("n_jobs", None) or 1 - - timepoints = kwargs.get("timepoints", None) or self._x - x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl - - self._datashape = data.shape[:3] - - # Convert data into V (voxels) x T (timepoints) - data = data.reshape((-1, data.shape[-1])) if self._mask is None else data[self._mask] - - # A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding) - A = BSpline.design_matrix(x, self._t, k=self._order) - AT = A.T - ATdotA = AT @ A - - # One single CPU - linear execution (full model) - if n_jobs == 1: - self._coeff = np.array([cg(ATdotA, AT @ v)[0] for v in data]) - return - - # Parallelize process with joblib - with Parallel(n_jobs=n_jobs) as executor: - results = executor(delayed(cg)(ATdotA, AT @ v) for v in data) - - self._coeff = np.array([r[0] for r in results]) - - def predict(self, index, **kwargs): - """Return the corrected volume using B-spline interpolation.""" - from scipy.interpolate import BSpline - - if not self._is_fitted: - raise ModelNotFittedError(f"{type(self).__name__} must be fitted before predicting") - - # Project sample timing into B-Spline coordinates - x = (index / self._xlim) * self._n_ctrl - A = BSpline.design_matrix(x, self._t, k=self._order) - - # A is 1 (num. timepoints) x C (num. coeff) - # self._coeff is V (num. voxels) x K - 4 - predicted = np.squeeze(A @ self._coeff.T) - - if self._mask is None: - return predicted.reshape(self._datashape) - - retval = np.zeros(self._datashape, dtype="float32") - retval[self._mask] = predicted - return retval - - class DTIModel(BaseDWIModel): """A wrapper of :obj:`dipy.reconst.dti.TensorModel`.""" diff --git a/test/test_model.py b/test/test_model.py index c5f319bf..753bd63f 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -42,7 +42,8 @@ def test_trivial_model(): tmodel = model.TrivialB0Model(gtab=np.eye(4), S0=_S0) - assert tmodel.fit() is None + data = None + assert tmodel.fit(data) is None assert np.all(_S0 == tmodel.predict((1, 0, 0))) @@ -106,6 +107,7 @@ def test_two_initialisations(datadir): # Direct initialisation model1 = model.AverageDWModel( + gtab=data_train[1], S0=dmri_dataset.bzero, th_low=100, th_high=1000,