diff --git a/eddymotion/estimator.py b/eddymotion/estimator.py index 6a43a7f3..f12631f1 100644 --- a/eddymotion/estimator.py +++ b/eddymotion/estimator.py @@ -1,5 +1,4 @@ """A model-based algorithm for the realignment of dMRI data.""" -from os import cpu_count from pathlib import Path from tempfile import TemporaryDirectory, mkstemp from pkg_resources import resource_filename as pkg_fn @@ -22,7 +21,6 @@ def fit( align_kwargs=None, model="b0", seed=None, - n_threads=None, **kwargs, ): r""" @@ -47,8 +45,6 @@ def fit( seed : :obj:`int` or :obj:`bool` Seed the random number generator (necessary when we want deterministic estimation). - n_threads : :obj:`int` - Number of threads to fit chunk-by-chunk the data . Return ------ @@ -73,7 +69,8 @@ def fit( kwargs["S0"] = _advanced_clip(dwdata.bzero) - kwargs["n_threads"] = n_threads or cpu_count() + if "n_threads" in kwargs: + align_kwargs["num_threads"] = kwargs["n_threads"] for i_iter in range(1, n_iter + 1): index_order = np.arange(len(dwdata)) diff --git a/eddymotion/model.py b/eddymotion/model.py index d47412fc..943be452 100644 --- a/eddymotion/model.py +++ b/eddymotion/model.py @@ -1,4 +1,5 @@ """A factory class that adapts DIPY's dMRI models.""" +from os import cpu_count import warnings from concurrent.futures import ThreadPoolExecutor import asyncio @@ -100,44 +101,29 @@ def predict(self, gradient, **kwargs): class DTIModel: """A wrapper of :obj:`dipy.reconst.dti.TensorModel.""" - __slots__ = ( - "_S0", - "_mask", - "_n_threads", - "_S0_chunks", - "_mask_chunks", - "_model_chunks" - ) + __slots__ = ("_model", "_S0", "_mask") - def __init__(self, gtab, S0=None, mask=None, nb_threads=1, **kwargs): + def __init__(self, gtab, S0=None, mask=None, **kwargs): """Instantiate the wrapped tensor model.""" from dipy.reconst.dti import TensorModel as DipyTensorModel - for k, v in kwargs.items(): - if k == 'n_threads': - self._n_threads = v - else: - self._n_threads = 1 + n_threads = kwargs.pop("n_threads", 0) or 0 + n_threads = n_threads if n_threads > 0 else cpu_count() self._S0 = None - self._S0_chunks = None if S0 is not None: self._S0 = np.clip( S0.astype("float32") / S0.max(), a_min=1e-5, a_max=1.0, ) - self._S0_chunks = np.split(S0, self._n_threads, axis=2) - self._mask = None - self._mask_chunks = None - if mask is None and S0 is not None: + self._mask = mask > 0 if mask is not None else None + if self._mask is None and self._S0 is not None: self._mask = self._S0 > np.percentile(self._S0, 35) - self._mask_chunks = np.split(self._mask, self._n_threads, axis=2) - if self._mask is not None: - self._S0 = self._S0[self._mask.astype(bool)] - self._S0_chunks = np.split(self._S0, self._n_threads, axis=2) + if self._S0 is not None: + self._S0 = self._S0[self._mask] kwargs = { k: v @@ -152,71 +138,67 @@ def __init__(self, gtab, S0=None, mask=None, nb_threads=1, **kwargs): "jac", ) } - - # Create a TensorModel for each chunk - self._model_chunks = [ - DipyTensorModel(gtab, **kwargs) - for _ in range(self._n_threads) - ] - - @staticmethod - def fit_chunk(model_chunk, data_chunk): - """Call model's fit.""" - return model_chunk.fit(data_chunk) + self._model = [DipyTensorModel(gtab, **kwargs)] * n_threads def fit(self, data, **kwargs): """Fit the model chunk-by-chunk asynchronously.""" - # Mask data if provided - if self._mask is not None: - data = data[self._mask, ...] + _nthreads = len(self._model) + + # All-true mask if not available + if self._mask is None: + self._mask = np.ones(data.shape[:3], dtype=bool) + + # Apply mask (ensures data is now 2D) + data = data[self._mask, ...] - # Split data into chunks of group of slices (axis=2) - data_chunks = np.split(data, self._n_threads, axis=2) + # Split data into chunks of group of slices + data_chunks = np.array_split(data, _nthreads) # Run asyncio tasks in a limited thread pool. - with ThreadPoolExecutor(max_workers=self._n_threads) as executor: + with ThreadPoolExecutor(max_workers=_nthreads) as executor: loop = asyncio.new_event_loop() fit_tasks = [ loop.run_in_executor( executor, - self.fit_chunk, - self._model_chunks[i], - data_chunks[i] + _model_fit, + model, + data, ) - for i in range(self._n_threads) + for model, data in zip(self._model, data_chunks) ] try: - self._model_chunks = loop.run_until_complete(asyncio.gather(*fit_tasks)) + self._model = loop.run_until_complete(asyncio.gather(*fit_tasks)) finally: loop.close() @staticmethod - def predict_chunk(model_chunk, S0_chunk, gradient, step=None): + def _predict_sub(submodel, gradient, S0_chunk, step): """Call predict for chunk and return the predicted diffusion signal.""" - return model_chunk.predict( - _rasb2dipy(gradient), - S0=S0_chunk, - step=step, - ) + return submodel.predict(gradient, S0=S0_chunk, step=step) def predict(self, gradient, step=None, **kwargs): """Predict asynchronously chunk-by-chunk the diffusion signal.""" + _nthreads = len(self._model) + S0 = [None] * _nthreads + if self._S0 is not None: + S0 = np.array_split(self._S0, _nthreads) + # Run asyncio tasks in a limited thread pool. - with ThreadPoolExecutor(max_workers=self._n_threads) as executor: + with ThreadPoolExecutor(max_workers=_nthreads) as executor: loop = asyncio.new_event_loop() predict_tasks = [ loop.run_in_executor( executor, - self.predict_chunk, - self._model_chunks[i], - self._S0_chunks[i], - gradient, - step + self._predict_sub, + model, + _rasb2dipy(gradient), + S0_chunk, + step, ) - for i in range(self._n_threads) + for model, S0_chunk in zip(self._model, S0) ] try: @@ -224,13 +206,9 @@ def predict(self, gradient, step=None, **kwargs): finally: loop.close() - predicted = np.squeeze(np.concatenate(predicted, axis=2)) - - if predicted.ndim == 3: - return predicted - + predicted = np.squeeze(np.concatenate(predicted, axis=0)) retval = np.zeros_like(self._mask, dtype="float32") - retval[self._mask, ...] = predicted + retval[self._mask] = predicted return retval @@ -308,3 +286,7 @@ def _rasb2dipy(gradient): warnings.filterwarnings("ignore", category=UserWarning) retval = gradient_table(gradient[3, :], gradient[:3, :].T) return retval + + +def _model_fit(model, data): + return model.fit(data)