Skip to content

Commit

Permalink
Merge pull request nipreps#34 from sebastientourbier/enh/parallel
Browse files Browse the repository at this point in the history
ENH: Parallelize DTI model fit and prediction
  • Loading branch information
oesteban authored Apr 22, 2021
2 parents 643d016 + 8cc27f1 commit 948b647
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 19 deletions.
3 changes: 3 additions & 0 deletions eddymotion/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def fit(

kwargs["S0"] = _advanced_clip(dwdata.bzero)

if "n_threads" in kwargs:
align_kwargs["num_threads"] = kwargs["n_threads"]

for i_iter in range(1, n_iter + 1):
index_order = np.arange(len(dwdata))
np.random.shuffle(index_order)
Expand Down
104 changes: 85 additions & 19 deletions eddymotion/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
"""A factory class that adapts DIPY's dMRI models."""
from os import cpu_count
import warnings
from concurrent.futures import ThreadPoolExecutor
import asyncio
import nest_asyncio

import numpy as np
from dipy.core.gradients import gradient_table

nest_asyncio.apply()


class ModelFactory:
"""A factory for instantiating diffusion models."""
Expand Down Expand Up @@ -100,19 +107,23 @@ def __init__(self, gtab, S0=None, mask=None, **kwargs):
"""Instantiate the wrapped tensor model."""
from dipy.reconst.dti import TensorModel as DipyTensorModel

n_threads = kwargs.pop("n_threads", 0) or 0
n_threads = n_threads if n_threads > 0 else cpu_count()

self._S0 = None
if S0 is not None:
self._S0 = np.clip(
S0.astype("float32") / S0.max(),
a_min=1e-5,
a_max=1.0,
)
self._mask = mask
if mask is None and S0 is not None:

self._mask = mask > 0 if mask is not None else None
if self._mask is None and self._S0 is not None:
self._mask = self._S0 > np.percentile(self._S0, 35)

if self._mask is not None:
self._S0 = self._S0[self._mask.astype(bool)]
if self._S0 is not None:
self._S0 = self._S0[self._mask]

kwargs = {
k: v
Expand All @@ -127,26 +138,77 @@ def __init__(self, gtab, S0=None, mask=None, **kwargs):
"jac",
)
}
self._model = DipyTensorModel(gtab, **kwargs)
self._model = [DipyTensorModel(gtab, **kwargs)] * n_threads

def fit(self, data, **kwargs):
"""Call model's fit."""
self._model = self._model.fit(data[self._mask, ...])
"""Fit the model chunk-by-chunk asynchronously."""
_nthreads = len(self._model)

# All-true mask if not available
if self._mask is None:
self._mask = np.ones(data.shape[:3], dtype=bool)

# Apply mask (ensures data is now 2D)
data = data[self._mask, ...]

# Split data into chunks of group of slices
data_chunks = np.array_split(data, _nthreads)

# Run asyncio tasks in a limited thread pool.
with ThreadPoolExecutor(max_workers=_nthreads) as executor:
loop = asyncio.new_event_loop()

fit_tasks = [
loop.run_in_executor(
executor,
_model_fit,
model,
data,
)
for model, data in zip(self._model, data_chunks)
]

try:
self._model = loop.run_until_complete(asyncio.gather(*fit_tasks))
finally:
loop.close()

def predict(self, gradient, step=None, **kwargs):
"""Propagate model parameters and call predict."""
predicted = np.squeeze(
self._model.predict(
_rasb2dipy(gradient),
S0=self._S0,
step=step,
)
)
if predicted.ndim == 3:
return predicted
@staticmethod
def _predict_sub(submodel, gradient, S0_chunk, step):
"""Call predict for chunk and return the predicted diffusion signal."""
return submodel.predict(gradient, S0=S0_chunk, step=step)

def predict(self, gradient, step=None, **kwargs):
"""Predict asynchronously chunk-by-chunk the diffusion signal."""
_nthreads = len(self._model)
S0 = [None] * _nthreads
if self._S0 is not None:
S0 = np.array_split(self._S0, _nthreads)

# Run asyncio tasks in a limited thread pool.
with ThreadPoolExecutor(max_workers=_nthreads) as executor:
loop = asyncio.new_event_loop()

predict_tasks = [
loop.run_in_executor(
executor,
self._predict_sub,
model,
_rasb2dipy(gradient),
S0_chunk,
step,
)
for model, S0_chunk in zip(self._model, S0)
]

try:
predicted = loop.run_until_complete(asyncio.gather(*predict_tasks))
finally:
loop.close()

predicted = np.squeeze(np.concatenate(predicted, axis=0))
retval = np.zeros_like(self._mask, dtype="float32")
retval[self._mask, ...] = predicted
retval[self._mask] = predicted
return retval


Expand Down Expand Up @@ -224,3 +286,7 @@ def _rasb2dipy(gradient):
warnings.filterwarnings("ignore", category=UserWarning)
retval = gradient_table(gradient[3, :], gradient[:3, :].T)
return retval


def _model_fit(model, data):
return model.fit(data)
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ python_requires = >=3.7
install_requires =
dipy>=1.3.0
scikit-image>=0.14.2
nest-asyncio>=1.5.1
test_requires =
codecov
coverage
Expand Down

0 comments on commit 948b647

Please sign in to comment.