From 411f3b8e502cae09c471544e69d448ce601fb9c6 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 21 Apr 2021 22:07:26 +0200 Subject: [PATCH] enh: skim out boilerplate --- eddymotion/estimator.py | 3 ++ eddymotion/model.py | 80 +++++++++++++++++++++-------------------- 2 files changed, 44 insertions(+), 39 deletions(-) diff --git a/eddymotion/estimator.py b/eddymotion/estimator.py index 430370d8..f12631f1 100644 --- a/eddymotion/estimator.py +++ b/eddymotion/estimator.py @@ -69,6 +69,9 @@ def fit( kwargs["S0"] = _advanced_clip(dwdata.bzero) + if "n_threads" in kwargs: + align_kwargs["num_threads"] = kwargs["n_threads"] + for i_iter in range(1, n_iter + 1): index_order = np.arange(len(dwdata)) np.random.shuffle(index_order) diff --git a/eddymotion/model.py b/eddymotion/model.py index 874242d2..943be452 100644 --- a/eddymotion/model.py +++ b/eddymotion/model.py @@ -117,12 +117,13 @@ def __init__(self, gtab, S0=None, mask=None, **kwargs): a_min=1e-5, a_max=1.0, ) - self._mask = mask - if mask is None and self._S0 is not None: + + self._mask = mask > 0 if mask is not None else None + if self._mask is None and self._S0 is not None: self._mask = self._S0 > np.percentile(self._S0, 35) - if self._mask is not None and self._S0 is not None: - self._S0 = self._S0[self._mask.astype(bool)] + if self._S0 is not None: + self._S0 = self._S0[self._mask] kwargs = { k: v @@ -139,64 +140,65 @@ def __init__(self, gtab, S0=None, mask=None, **kwargs): } self._model = [DipyTensorModel(gtab, **kwargs)] * n_threads - @staticmethod - def fit_chunk(model_chunk, data_chunk): - """Call model's fit.""" - return model_chunk.fit(data_chunk) - def fit(self, data, **kwargs): """Fit the model chunk-by-chunk asynchronously.""" - # Mask data if provided - if self._mask is not None: - data = data[self._mask, ...] + _nthreads = len(self._model) + + # All-true mask if not available + if self._mask is None: + self._mask = np.ones(data.shape[:3], dtype=bool) - # Split data into chunks of group of slices (axis=2) - data_chunks = np.split(data, self._n_threads, axis=2) + # Apply mask (ensures data is now 2D) + data = data[self._mask, ...] + + # Split data into chunks of group of slices + data_chunks = np.array_split(data, _nthreads) # Run asyncio tasks in a limited thread pool. - with ThreadPoolExecutor(max_workers=self._n_threads) as executor: + with ThreadPoolExecutor(max_workers=_nthreads) as executor: loop = asyncio.new_event_loop() fit_tasks = [ loop.run_in_executor( executor, - self.fit_chunk, - self._model_chunks[i], - data_chunks[i] + _model_fit, + model, + data, ) - for i in range(self._n_threads) + for model, data in zip(self._model, data_chunks) ] try: - self._model_chunks = loop.run_until_complete(asyncio.gather(*fit_tasks)) + self._model = loop.run_until_complete(asyncio.gather(*fit_tasks)) finally: loop.close() @staticmethod - def predict_chunk(model_chunk, S0_chunk, gradient, step=None): + def _predict_sub(submodel, gradient, S0_chunk, step): """Call predict for chunk and return the predicted diffusion signal.""" - return model_chunk.predict( - _rasb2dipy(gradient), - S0=S0_chunk, - step=step, - ) + return submodel.predict(gradient, S0=S0_chunk, step=step) def predict(self, gradient, step=None, **kwargs): """Predict asynchronously chunk-by-chunk the diffusion signal.""" + _nthreads = len(self._model) + S0 = [None] * _nthreads + if self._S0 is not None: + S0 = np.array_split(self._S0, _nthreads) + # Run asyncio tasks in a limited thread pool. - with ThreadPoolExecutor(max_workers=self._n_threads) as executor: + with ThreadPoolExecutor(max_workers=_nthreads) as executor: loop = asyncio.new_event_loop() predict_tasks = [ loop.run_in_executor( executor, - self.predict_chunk, - self._model_chunks[i], - self._S0_chunks[i], - gradient, - step + self._predict_sub, + model, + _rasb2dipy(gradient), + S0_chunk, + step, ) - for i in range(self._n_threads) + for model, S0_chunk in zip(self._model, S0) ] try: @@ -204,13 +206,9 @@ def predict(self, gradient, step=None, **kwargs): finally: loop.close() - predicted = np.squeeze(np.concatenate(predicted, axis=2)) - - if predicted.ndim == 3: - return predicted - + predicted = np.squeeze(np.concatenate(predicted, axis=0)) retval = np.zeros_like(self._mask, dtype="float32") - retval[self._mask, ...] = predicted + retval[self._mask] = predicted return retval @@ -288,3 +286,7 @@ def _rasb2dipy(gradient): warnings.filterwarnings("ignore", category=UserWarning) retval = gradient_table(gradient[3, :], gradient[:3, :].T) return retval + + +def _model_fit(model, data): + return model.fit(data)