diff --git a/eddymotion/model.py b/eddymotion/model.py index 9ce83233..a22ce278 100644 --- a/eddymotion/model.py +++ b/eddymotion/model.py @@ -94,15 +94,15 @@ def predict(self, gradient, **kwargs): class AverageDWModel: """A trivial model that returns an average map.""" - __slots__ = ("_data",) + __slots__ = ("_data", "_gtab") def __init__(self, gtab, **kwargs): """Implement object initialization.""" - return # do nothing at initialization time + self._gtab = gtab def fit(self, data, **kwargs): """Calculate the average.""" - self._data = data.mean(-1) + self._data = data[..., self._gtab[..., 3] > 50].mean(-1) def predict(self, gradient, **kwargs): """Return the average map."""