diff --git a/src/eddymotion/model/base.py b/src/eddymotion/model/base.py index fe8aca40..f318dc53 100644 --- a/src/eddymotion/model/base.py +++ b/src/eddymotion/model/base.py @@ -131,6 +131,9 @@ def __init__(self, gtab, S0=None, mask=None, b_max=None, **kwargs): module_name, class_name = model_str.rsplit(".", 1) self._model = getattr(import_module(module_name), class_name)(_rasb2dipy(gtab), **kwargs) + self._datashape = None + self._models = None + def fit(self, data, n_jobs=None, **kwargs): """Fit the model chunk-by-chunk asynchronously""" n_jobs = n_jobs or 1 @@ -257,6 +260,7 @@ def __init__(self, **kwargs): self._th_high = kwargs.get("th_high", 10000) self._bias = kwargs.get("bias", True) self._stat = kwargs.get("stat", "median") + self._data = None def fit(self, data, **kwargs): """Calculate the average.""" @@ -329,6 +333,9 @@ def __init__(self, timepoints=None, xlim=None, n_ctrl=None, mask=None, order=3, # B-Spline knots self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32") + self._shape = None + self._coeff = None + def fit(self, data, *args, **kwargs): """Fit the model.""" from scipy.interpolate import BSpline