Skip to content

Commit

Permalink
enh: skim out boilerplate
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Apr 21, 2021
1 parent 80dadf1 commit 411f3b8
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 39 deletions.
3 changes: 3 additions & 0 deletions eddymotion/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def fit(

kwargs["S0"] = _advanced_clip(dwdata.bzero)

if "n_threads" in kwargs:
align_kwargs["num_threads"] = kwargs["n_threads"]

for i_iter in range(1, n_iter + 1):
index_order = np.arange(len(dwdata))
np.random.shuffle(index_order)
Expand Down
80 changes: 41 additions & 39 deletions eddymotion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,13 @@ def __init__(self, gtab, S0=None, mask=None, **kwargs):
a_min=1e-5,
a_max=1.0,
)
self._mask = mask
if mask is None and self._S0 is not None:

self._mask = mask > 0 if mask is not None else None
if self._mask is None and self._S0 is not None:
self._mask = self._S0 > np.percentile(self._S0, 35)

if self._mask is not None and self._S0 is not None:
self._S0 = self._S0[self._mask.astype(bool)]
if self._S0 is not None:
self._S0 = self._S0[self._mask]

kwargs = {
k: v
Expand All @@ -139,78 +140,75 @@ def __init__(self, gtab, S0=None, mask=None, **kwargs):
}
self._model = [DipyTensorModel(gtab, **kwargs)] * n_threads

@staticmethod
def fit_chunk(model_chunk, data_chunk):
"""Call model's fit."""
return model_chunk.fit(data_chunk)

def fit(self, data, **kwargs):
"""Fit the model chunk-by-chunk asynchronously."""
# Mask data if provided
if self._mask is not None:
data = data[self._mask, ...]
_nthreads = len(self._model)

# All-true mask if not available
if self._mask is None:
self._mask = np.ones(data.shape[:3], dtype=bool)

# Split data into chunks of group of slices (axis=2)
data_chunks = np.split(data, self._n_threads, axis=2)
# Apply mask (ensures data is now 2D)
data = data[self._mask, ...]

# Split data into chunks of group of slices
data_chunks = np.array_split(data, _nthreads)

# Run asyncio tasks in a limited thread pool.
with ThreadPoolExecutor(max_workers=self._n_threads) as executor:
with ThreadPoolExecutor(max_workers=_nthreads) as executor:
loop = asyncio.new_event_loop()

fit_tasks = [
loop.run_in_executor(
executor,
self.fit_chunk,
self._model_chunks[i],
data_chunks[i]
_model_fit,
model,
data,
)
for i in range(self._n_threads)
for model, data in zip(self._model, data_chunks)
]

try:
self._model_chunks = loop.run_until_complete(asyncio.gather(*fit_tasks))
self._model = loop.run_until_complete(asyncio.gather(*fit_tasks))
finally:
loop.close()

@staticmethod
def predict_chunk(model_chunk, S0_chunk, gradient, step=None):
def _predict_sub(submodel, gradient, S0_chunk, step):
"""Call predict for chunk and return the predicted diffusion signal."""
return model_chunk.predict(
_rasb2dipy(gradient),
S0=S0_chunk,
step=step,
)
return submodel.predict(gradient, S0=S0_chunk, step=step)

def predict(self, gradient, step=None, **kwargs):
"""Predict asynchronously chunk-by-chunk the diffusion signal."""
_nthreads = len(self._model)
S0 = [None] * _nthreads
if self._S0 is not None:
S0 = np.array_split(self._S0, _nthreads)

# Run asyncio tasks in a limited thread pool.
with ThreadPoolExecutor(max_workers=self._n_threads) as executor:
with ThreadPoolExecutor(max_workers=_nthreads) as executor:
loop = asyncio.new_event_loop()

predict_tasks = [
loop.run_in_executor(
executor,
self.predict_chunk,
self._model_chunks[i],
self._S0_chunks[i],
gradient,
step
self._predict_sub,
model,
_rasb2dipy(gradient),
S0_chunk,
step,
)
for i in range(self._n_threads)
for model, S0_chunk in zip(self._model, S0)
]

try:
predicted = loop.run_until_complete(asyncio.gather(*predict_tasks))
finally:
loop.close()

predicted = np.squeeze(np.concatenate(predicted, axis=2))

if predicted.ndim == 3:
return predicted

predicted = np.squeeze(np.concatenate(predicted, axis=0))
retval = np.zeros_like(self._mask, dtype="float32")
retval[self._mask, ...] = predicted
retval[self._mask] = predicted
return retval


Expand Down Expand Up @@ -288,3 +286,7 @@ def _rasb2dipy(gradient):
warnings.filterwarnings("ignore", category=UserWarning)
retval = gradient_table(gradient[3, :], gradient[:3, :].T)
return retval


def _model_fit(model, data):
return model.fit(data)

0 comments on commit 411f3b8

Please sign in to comment.