Merge pull request nipreps#3 from oesteban/patch/enh/parallel
Code review
sebastientourbier authored Apr 22, 2021
2 parents 7c77010 + 411f3b8 commit 8cc27f1
Showing 2 changed files with 48 additions and 69 deletions.
7 changes: 2 additions & 5 deletions eddymotion/estimator.py
@@ -1,5 +1,4 @@
"""A model-based algorithm for the realignment of dMRI data."""
from os import cpu_count
from pathlib import Path
from tempfile import TemporaryDirectory, mkstemp
from pkg_resources import resource_filename as pkg_fn
@@ -22,7 +21,6 @@ def fit(
align_kwargs=None,
model="b0",
seed=None,
n_threads=None,
**kwargs,
):
r"""
@@ -47,8 +45,6 @@ def fit(
seed : :obj:`int` or :obj:`bool`
Seed the random number generator (necessary when we want deterministic
estimation).
n_threads : :obj:`int`
Number of threads to fit the data chunk-by-chunk.
Return
------
@@ -73,7 +69,8 @@ def fit(

kwargs["S0"] = _advanced_clip(dwdata.bzero)

kwargs["n_threads"] = n_threads or cpu_count()
if "n_threads" in kwargs:
align_kwargs["num_threads"] = kwargs["n_threads"]

for i_iter in range(1, n_iter + 1):
index_order = np.arange(len(dwdata))
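With this change the estimator no longer computes a thread count itself; it only relays an explicit n_threads request to the registration step through align_kwargs["num_threads"], while the model resolves its own default. A minimal standalone sketch of that hand-off (dispatch_threads is a hypothetical name, not part of eddymotion's API):

from os import cpu_count

def dispatch_threads(align_kwargs=None, **kwargs):
    """Relay an explicit thread request to registration; default the model."""
    align_kwargs = dict(align_kwargs or {})
    if "n_threads" in kwargs:
        # Explicit request: also forward it to the registration tool.
        align_kwargs["num_threads"] = kwargs["n_threads"]
    # Model-side fallback: use every available CPU when nothing was passed.
    model_threads = kwargs.pop("n_threads", 0) or cpu_count()
    return model_threads, align_kwargs

print(dispatch_threads(n_threads=4))  # (4, {'num_threads': 4})
print(dispatch_threads())             # (cpu_count(), {})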
110 changes: 46 additions & 64 deletions eddymotion/model.py
@@ -1,4 +1,5 @@
"""A factory class that adapts DIPY's dMRI models."""
from os import cpu_count
import warnings
from concurrent.futures import ThreadPoolExecutor
import asyncio
@@ -100,44 +101,29 @@ def predict(self, gradient, **kwargs):
class DTIModel:
"""A wrapper of :obj:`dipy.reconst.dti.TensorModel."""

__slots__ = (
"_S0",
"_mask",
"_n_threads",
"_S0_chunks",
"_mask_chunks",
"_model_chunks"
)
__slots__ = ("_model", "_S0", "_mask")

def __init__(self, gtab, S0=None, mask=None, nb_threads=1, **kwargs):
def __init__(self, gtab, S0=None, mask=None, **kwargs):
"""Instantiate the wrapped tensor model."""
from dipy.reconst.dti import TensorModel as DipyTensorModel

for k, v in kwargs.items():
if k == 'n_threads':
self._n_threads = v
else:
self._n_threads = 1
n_threads = kwargs.pop("n_threads", 0) or 0
n_threads = n_threads if n_threads > 0 else cpu_count()

self._S0 = None
self._S0_chunks = None
if S0 is not None:
self._S0 = np.clip(
S0.astype("float32") / S0.max(),
a_min=1e-5,
a_max=1.0,
)
self._S0_chunks = np.split(S0, self._n_threads, axis=2)

self._mask = None
self._mask_chunks = None
if mask is None and S0 is not None:
self._mask = mask > 0 if mask is not None else None
if self._mask is None and self._S0 is not None:
self._mask = self._S0 > np.percentile(self._S0, 35)
self._mask_chunks = np.split(self._mask, self._n_threads, axis=2)

if self._mask is not None:
self._S0 = self._S0[self._mask.astype(bool)]
self._S0_chunks = np.split(self._S0, self._n_threads, axis=2)
if self._S0 is not None:
self._S0 = self._S0[self._mask]

kwargs = {
k: v
@@ -152,85 +138,77 @@ def __init__(self, gtab, S0=None, mask=None, nb_threads=1, **kwargs):
"jac",
)
}

# Create a TensorModel for each chunk
self._model_chunks = [
DipyTensorModel(gtab, **kwargs)
for _ in range(self._n_threads)
]

@staticmethod
def fit_chunk(model_chunk, data_chunk):
"""Call model's fit."""
return model_chunk.fit(data_chunk)
self._model = [DipyTensorModel(gtab, **kwargs)] * n_threads
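Note that [DipyTensorModel(gtab, **kwargs)] * n_threads creates n_threads references to a single, not-yet-fitted model; that is harmless here because fit() below rebinds self._model to the list of fitted results returned by asyncio.gather. A two-line illustration of the aliasing (toy class, not eddymotion code):

class Toy:
    """Stand-in for DipyTensorModel."""

models = [Toy()] * 3
print(models[0] is models[1] is models[2])  # True: three names, one instance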

def fit(self, data, **kwargs):
"""Fit the model chunk-by-chunk asynchronously."""
# Mask data if provided
if self._mask is not None:
data = data[self._mask, ...]
_nthreads = len(self._model)

# All-true mask if not available
if self._mask is None:
self._mask = np.ones(data.shape[:3], dtype=bool)

# Apply mask (ensures data is now 2D)
data = data[self._mask, ...]

# Split data into chunks of group of slices (axis=2)
data_chunks = np.split(data, self._n_threads, axis=2)
# Split data into chunks of group of slices
data_chunks = np.array_split(data, _nthreads)

# Run asyncio tasks in a limited thread pool.
with ThreadPoolExecutor(max_workers=self._n_threads) as executor:
with ThreadPoolExecutor(max_workers=_nthreads) as executor:
loop = asyncio.new_event_loop()

fit_tasks = [
loop.run_in_executor(
executor,
self.fit_chunk,
self._model_chunks[i],
data_chunks[i]
_model_fit,
model,
data,
)
for i in range(self._n_threads)
for model, data in zip(self._model, data_chunks)
]

try:
self._model_chunks = loop.run_until_complete(asyncio.gather(*fit_tasks))
self._model = loop.run_until_complete(asyncio.gather(*fit_tasks))
finally:
loop.close()

@staticmethod
def predict_chunk(model_chunk, S0_chunk, gradient, step=None):
def _predict_sub(submodel, gradient, S0_chunk, step):
"""Call predict for chunk and return the predicted diffusion signal."""
return model_chunk.predict(
_rasb2dipy(gradient),
S0=S0_chunk,
step=step,
)
return submodel.predict(gradient, S0=S0_chunk, step=step)

def predict(self, gradient, step=None, **kwargs):
"""Predict asynchronously chunk-by-chunk the diffusion signal."""
_nthreads = len(self._model)
S0 = [None] * _nthreads
if self._S0 is not None:
S0 = np.array_split(self._S0, _nthreads)

# Run asyncio tasks in a limited thread pool.
with ThreadPoolExecutor(max_workers=self._n_threads) as executor:
with ThreadPoolExecutor(max_workers=_nthreads) as executor:
loop = asyncio.new_event_loop()

predict_tasks = [
loop.run_in_executor(
executor,
self.predict_chunk,
self._model_chunks[i],
self._S0_chunks[i],
gradient,
step
self._predict_sub,
model,
_rasb2dipy(gradient),
S0_chunk,
step,
)
for i in range(self._n_threads)
for model, S0_chunk in zip(self._model, S0)
]

try:
predicted = loop.run_until_complete(asyncio.gather(*predict_tasks))
finally:
loop.close()

predicted = np.squeeze(np.concatenate(predicted, axis=2))

if predicted.ndim == 3:
return predicted

predicted = np.squeeze(np.concatenate(predicted, axis=0))
retval = np.zeros_like(self._mask, dtype="float32")
retval[self._mask, ...] = predicted
retval[self._mask] = predicted
return retval
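predict computes values only for the masked voxels, concatenates the per-chunk results, and scatters them back into a full-size volume with boolean-mask assignment. A tiny NumPy sketch of that scatter step (shapes invented for illustration):

import numpy as np

mask = np.zeros((2, 2, 2), dtype=bool)
mask[0, 0, 0] = mask[1, 1, 1] = True

predicted = np.array([10.0, 20.0], dtype="float32")  # one value per masked voxel

retval = np.zeros_like(mask, dtype="float32")
retval[mask] = predicted  # scatter masked-voxel predictions into the full volume
print(retval[0, 0, 0], retval[1, 1, 1])  # 10.0 20.0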


@@ -308,3 +286,7 @@ def _rasb2dipy(gradient):
warnings.filterwarnings("ignore", category=UserWarning)
retval = gradient_table(gradient[3, :], gradient[:3, :].T)
return retval


def _model_fit(model, data):
return model.fit(data)
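Taken together, the new model.py follows one pattern: split the masked data into roughly equal chunks with np.array_split, fit each chunk in its own worker thread via asyncio tasks scheduled on a ThreadPoolExecutor, and gather the per-chunk results. A self-contained sketch of that pattern, with a toy model standing in for dipy's TensorModel:

import asyncio
from concurrent.futures import ThreadPoolExecutor
from os import cpu_count

import numpy as np

class ToyModel:
    def fit(self, data):
        return data.mean()  # stand-in for fitting one chunk

def parallel_fit(data, n_threads=None):
    n_threads = n_threads or cpu_count()
    models = [ToyModel()] * n_threads
    chunks = np.array_split(data, n_threads)
    with ThreadPoolExecutor(max_workers=n_threads) as executor:
        loop = asyncio.new_event_loop()
        tasks = [
            loop.run_in_executor(executor, model.fit, chunk)
            for model, chunk in zip(models, chunks)
        ]
        try:
            return loop.run_until_complete(asyncio.gather(*tasks))
        finally:
            loop.close()

print(parallel_fit(np.arange(100.0), n_threads=4))  # four per-chunk means

Whether this yields a real speed-up depends on how much of the per-chunk fit releases the GIL; the pattern itself only handles the splitting and scheduling.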
