From d63b1e205e7e2acdd878364e236964a57a820ced Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 21 Apr 2021 17:13:48 +0200 Subject: [PATCH 1/3] fix: do not expose kwarg, better initialization --- eddymotion/estimator.py | 6 ------ eddymotion/model.py | 10 ++++------ 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/eddymotion/estimator.py b/eddymotion/estimator.py index 6a43a7f3..430370d8 100644 --- a/eddymotion/estimator.py +++ b/eddymotion/estimator.py @@ -1,5 +1,4 @@ """A model-based algorithm for the realignment of dMRI data.""" -from os import cpu_count from pathlib import Path from tempfile import TemporaryDirectory, mkstemp from pkg_resources import resource_filename as pkg_fn @@ -22,7 +21,6 @@ def fit( align_kwargs=None, model="b0", seed=None, - n_threads=None, **kwargs, ): r""" @@ -47,8 +45,6 @@ def fit( seed : :obj:`int` or :obj:`bool` Seed the random number generator (necessary when we want deterministic estimation). - n_threads : :obj:`int` - Number of threads to fit chunk-by-chunk the data . Return ------ @@ -73,8 +69,6 @@ def fit( kwargs["S0"] = _advanced_clip(dwdata.bzero) - kwargs["n_threads"] = n_threads or cpu_count() - for i_iter in range(1, n_iter + 1): index_order = np.arange(len(dwdata)) np.random.shuffle(index_order) diff --git a/eddymotion/model.py b/eddymotion/model.py index d47412fc..8bc206e0 100644 --- a/eddymotion/model.py +++ b/eddymotion/model.py @@ -1,4 +1,5 @@ """A factory class that adapts DIPY's dMRI models.""" +from os import cpu_count import warnings from concurrent.futures import ThreadPoolExecutor import asyncio @@ -109,15 +110,12 @@ class DTIModel: "_model_chunks" ) - def __init__(self, gtab, S0=None, mask=None, nb_threads=1, **kwargs): + def __init__(self, gtab, S0=None, mask=None, **kwargs): """Instantiate the wrapped tensor model.""" from dipy.reconst.dti import TensorModel as DipyTensorModel - for k, v in kwargs.items(): - if k == 'n_threads': - self._n_threads = v - else: - self._n_threads = 1 + n_threads = kwargs.get("n_threads", 0) or 0 + self._n_threads = n_threads if n_threads > 0 else cpu_count() self._S0 = None self._S0_chunks = None From 80dadf147a610d4156d35109c3e1cc8bd98ee71c Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 21 Apr 2021 17:46:16 +0200 Subject: [PATCH 2/3] enh: minimize changes in ``__init__`` There's no need to duplicate efforts: * Data can be split before fit/predict. * The number of threads == number of models initialized. This also simplifies the loop to create the list of models. --- eddymotion/model.py | 32 +++++++------------------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/eddymotion/model.py b/eddymotion/model.py index 8bc206e0..874242d2 100644 --- a/eddymotion/model.py +++ b/eddymotion/model.py @@ -101,41 +101,28 @@ def predict(self, gradient, **kwargs): class DTIModel: """A wrapper of :obj:`dipy.reconst.dti.TensorModel.""" - __slots__ = ( - "_S0", - "_mask", - "_n_threads", - "_S0_chunks", - "_mask_chunks", - "_model_chunks" - ) + __slots__ = ("_model", "_S0", "_mask") def __init__(self, gtab, S0=None, mask=None, **kwargs): """Instantiate the wrapped tensor model.""" from dipy.reconst.dti import TensorModel as DipyTensorModel - n_threads = kwargs.get("n_threads", 0) or 0 - self._n_threads = n_threads if n_threads > 0 else cpu_count() + n_threads = kwargs.pop("n_threads", 0) or 0 + n_threads = n_threads if n_threads > 0 else cpu_count() self._S0 = None - self._S0_chunks = None if S0 is not None: self._S0 = np.clip( S0.astype("float32") / S0.max(), a_min=1e-5, a_max=1.0, ) - self._S0_chunks = np.split(S0, self._n_threads, axis=2) - - self._mask = None - self._mask_chunks = None - if mask is None and S0 is not None: + self._mask = mask + if mask is None and self._S0 is not None: self._mask = self._S0 > np.percentile(self._S0, 35) - self._mask_chunks = np.split(self._mask, self._n_threads, axis=2) - if self._mask is not None: + if self._mask is not None and self._S0 is not None: self._S0 = self._S0[self._mask.astype(bool)] - self._S0_chunks = np.split(self._S0, self._n_threads, axis=2) kwargs = { k: v @@ -150,12 +137,7 @@ def __init__(self, gtab, S0=None, mask=None, **kwargs): "jac", ) } - - # Create a TensorModel for each chunk - self._model_chunks = [ - DipyTensorModel(gtab, **kwargs) - for _ in range(self._n_threads) - ] + self._model = [DipyTensorModel(gtab, **kwargs)] * n_threads @staticmethod def fit_chunk(model_chunk, data_chunk): From 411f3b8e502cae09c471544e69d448ce601fb9c6 Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Wed, 21 Apr 2021 22:07:26 +0200 Subject: [PATCH 3/3] enh: skim out boilerplate --- eddymotion/estimator.py | 3 ++ eddymotion/model.py | 80 +++++++++++++++++++++-------------------- 2 files changed, 44 insertions(+), 39 deletions(-) diff --git a/eddymotion/estimator.py b/eddymotion/estimator.py index 430370d8..f12631f1 100644 --- a/eddymotion/estimator.py +++ b/eddymotion/estimator.py @@ -69,6 +69,9 @@ def fit( kwargs["S0"] = _advanced_clip(dwdata.bzero) + if "n_threads" in kwargs: + align_kwargs["num_threads"] = kwargs["n_threads"] + for i_iter in range(1, n_iter + 1): index_order = np.arange(len(dwdata)) np.random.shuffle(index_order) diff --git a/eddymotion/model.py b/eddymotion/model.py index 874242d2..943be452 100644 --- a/eddymotion/model.py +++ b/eddymotion/model.py @@ -117,12 +117,13 @@ def __init__(self, gtab, S0=None, mask=None, **kwargs): a_min=1e-5, a_max=1.0, ) - self._mask = mask - if mask is None and self._S0 is not None: + + self._mask = mask > 0 if mask is not None else None + if self._mask is None and self._S0 is not None: self._mask = self._S0 > np.percentile(self._S0, 35) - if self._mask is not None and self._S0 is not None: - self._S0 = self._S0[self._mask.astype(bool)] + if self._S0 is not None: + self._S0 = self._S0[self._mask] kwargs = { k: v @@ -139,64 +140,65 @@ def __init__(self, gtab, S0=None, mask=None, **kwargs): } self._model = [DipyTensorModel(gtab, **kwargs)] * n_threads - @staticmethod - def fit_chunk(model_chunk, data_chunk): - """Call model's fit.""" - return model_chunk.fit(data_chunk) - def fit(self, data, **kwargs): """Fit the model chunk-by-chunk asynchronously.""" - # Mask data if provided - if self._mask is not None: - data = data[self._mask, ...] + _nthreads = len(self._model) + + # All-true mask if not available + if self._mask is None: + self._mask = np.ones(data.shape[:3], dtype=bool) - # Split data into chunks of group of slices (axis=2) - data_chunks = np.split(data, self._n_threads, axis=2) + # Apply mask (ensures data is now 2D) + data = data[self._mask, ...] + + # Split data into chunks of group of slices + data_chunks = np.array_split(data, _nthreads) # Run asyncio tasks in a limited thread pool. - with ThreadPoolExecutor(max_workers=self._n_threads) as executor: + with ThreadPoolExecutor(max_workers=_nthreads) as executor: loop = asyncio.new_event_loop() fit_tasks = [ loop.run_in_executor( executor, - self.fit_chunk, - self._model_chunks[i], - data_chunks[i] + _model_fit, + model, + data, ) - for i in range(self._n_threads) + for model, data in zip(self._model, data_chunks) ] try: - self._model_chunks = loop.run_until_complete(asyncio.gather(*fit_tasks)) + self._model = loop.run_until_complete(asyncio.gather(*fit_tasks)) finally: loop.close() @staticmethod - def predict_chunk(model_chunk, S0_chunk, gradient, step=None): + def _predict_sub(submodel, gradient, S0_chunk, step): """Call predict for chunk and return the predicted diffusion signal.""" - return model_chunk.predict( - _rasb2dipy(gradient), - S0=S0_chunk, - step=step, - ) + return submodel.predict(gradient, S0=S0_chunk, step=step) def predict(self, gradient, step=None, **kwargs): """Predict asynchronously chunk-by-chunk the diffusion signal.""" + _nthreads = len(self._model) + S0 = [None] * _nthreads + if self._S0 is not None: + S0 = np.array_split(self._S0, _nthreads) + # Run asyncio tasks in a limited thread pool. - with ThreadPoolExecutor(max_workers=self._n_threads) as executor: + with ThreadPoolExecutor(max_workers=_nthreads) as executor: loop = asyncio.new_event_loop() predict_tasks = [ loop.run_in_executor( executor, - self.predict_chunk, - self._model_chunks[i], - self._S0_chunks[i], - gradient, - step + self._predict_sub, + model, + _rasb2dipy(gradient), + S0_chunk, + step, ) - for i in range(self._n_threads) + for model, S0_chunk in zip(self._model, S0) ] try: @@ -204,13 +206,9 @@ def predict(self, gradient, step=None, **kwargs): finally: loop.close() - predicted = np.squeeze(np.concatenate(predicted, axis=2)) - - if predicted.ndim == 3: - return predicted - + predicted = np.squeeze(np.concatenate(predicted, axis=0)) retval = np.zeros_like(self._mask, dtype="float32") - retval[self._mask, ...] = predicted + retval[self._mask] = predicted return retval @@ -288,3 +286,7 @@ def _rasb2dipy(gradient): warnings.filterwarnings("ignore", category=UserWarning) retval = gradient_table(gradient[3, :], gradient[:3, :].T) return retval + + +def _model_fit(model, data): + return model.fit(data)