Skip to content

Commit

Permalink
Merge pull request #25 from 36000/distance_transform_rois
Browse files Browse the repository at this point in the history
[ENH] Use Distance Transform for Include/Exclude Distance calculation; fix ROI dists
  • Loading branch information
36000 authored Nov 22, 2024
2 parents 5390957 + 9634624 commit bcbb266
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 167 deletions.
116 changes: 67 additions & 49 deletions AFQ/api/bundle_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,62 +946,16 @@ def copy(self):
resample_subject_to=self.resample_subject_to,
keep_in_memory=self.keep_in_memory)

def apply_to_rois(self, b_name, func, *args,
dry_run=False, apply_to_recobundles=False,
**kwargs):
def apply_to_rois(self, b_name, *args, **kwargs):
"""
Applies some transformation to all ROIs (include, exclude, end, start)
and the prob_map in a given bundle.
See: AFQ.api.bundle_dict.apply_to_roi_dict
Parameters
----------
b_name : name
bundle name of bundle whose ROIs will be transformed.
func : function
function whose first argument must be a Nifti1Image and which
returns a Nifti1Image
dry_run : bool
Whether to actually apply changes returned by `func` to the ROIs.
If has_return is False, dry_run is not used.
apply_to_recobundles : bool, optional
Whether to apply the transformation to recobundles
TRKs as well.
Default: False
*args :
Additional arguments for func
**kwargs
Optional arguments for func
Returns
-------
A dictionary where keys are
the roi type and values are the transformed ROIs.
"""
return_vals = {}
for roi_type in [
"include", "exclude",
"start", "end", "prob_map"]:
if roi_type in self._dict[b_name]:
if roi_type in ["start", "end", "prob_map"]:
return_vals[roi_type] = func(
self._dict[b_name][roi_type], *args, **kwargs)
else:
changed_rois = []
for _roi in self._dict[b_name][roi_type]:
changed_rois.append(func(
_roi, *args, **kwargs))
return_vals[roi_type] = changed_rois
if apply_to_recobundles and "recobundles" in self._dict[b_name]:
return_vals["recobundles"] = {}
for sl_type in ["sl", "centroid"]:
return_vals["recobundles"][sl_type] = func(
self._dict[b_name]["recobundles"][sl_type],
*args, **kwargs)
if not dry_run:
for roi_type, roi in return_vals.items():
self._dict[b_name][roi_type] = roi
return return_vals
return apply_to_roi_dict(self._dict[b_name], *args, **kwargs)

def _cond_load_bundle(self, b_name, dry_run=False):
"""
Expand Down Expand Up @@ -1146,3 +1100,67 @@ def __add__(self, other):
self.resample_to,
self.resample_subject_to,
self.keep_in_memory)


def apply_to_roi_dict(dict_, func, *args,
dry_run=False, apply_to_recobundles=False,
apply_to_prob_map=True,
**kwargs):
"""
Applies some transformation to all ROIs (include, exclude, end, start)
and the prob_map in a given bundle.
Parameters
----------
dict_: dict
dict describing the bundle using pyAFQ's format. An entry
in a BundleDict.
func : function
function whose first argument must be a Nifti1Image and which
returns a Nifti1Image
dry_run : bool
Whether to actually apply changes returned by `func` to the ROIs.
If has_return is False, dry_run is not used.
apply_to_recobundles : bool, optional
Whether to apply the transformation to recobundles
TRKs as well.
Default: False
apply_to_prob_map : bool, optional
Whether to apply the transformation to the prob_map.
Default: True
*args :
Additional arguments for func
**kwargs
Optional arguments for func
Returns
-------
A dictionary where keys are
the roi type and values are the transformed ROIs.
"""
return_vals = {}
roi_types = ["include", "exclude", "start", "end"]
if apply_to_prob_map:
roi_types.append("prob_map")
for roi_type in roi_types:
if roi_type in dict_:
if roi_type in ["start", "end", "prob_map"]:
return_vals[roi_type] = func(
dict_[roi_type], *args, **kwargs)
else:
changed_rois = []
for _roi in dict_[roi_type]:
changed_rois.append(func(
_roi, *args, **kwargs))
return_vals[roi_type] = changed_rois
if apply_to_recobundles and "recobundles" in dict_:
return_vals["recobundles"] = {}
for sl_type in ["sl", "centroid"]:
return_vals["recobundles"][sl_type] = func(
dict_["recobundles"][sl_type],
*args, **kwargs)
if not dry_run:
for roi_type, roi in return_vals.items():
dict_[roi_type] = roi
return return_vals
70 changes: 35 additions & 35 deletions AFQ/recognition/criteria.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import numpy as np
import logging
from time import time

import numpy as np
import nibabel as nib

from scipy.ndimage import distance_transform_edt

import dipy.tracking.streamline as dts
from dipy.utils.parallel import paramap
from dipy.segment.clustering import QuickBundles
Expand All @@ -11,6 +15,7 @@
from dipy.segment.bundles import RecoBundles
from dipy.io.stateful_tractogram import StatefulTractogram, Space

from AFQ.api.bundle_dict import apply_to_roi_dict
import AFQ.recognition.utils as abu
import AFQ.recognition.cleaning as abc
import AFQ.recognition.curvature as abv
Expand Down Expand Up @@ -132,54 +137,43 @@ def include(b_sls, bundle_def, preproc_imap, max_includes,
include_roi_tols = [preproc_imap["tol"]**2] * len(
bundle_def["include"])

include_rois = []
for include_roi in bundle_def["include"]:
include_rois.append(np.array(
np.where(include_roi.get_fdata())).T)

# with parallel segmentation, the first for loop will
# only collect streamlines and does not need tqdm
if parallel_segmentation["engine"] != "serial":
inc_results = paramap(
abr.check_sl_with_inclusion, b_sls.get_selected_sls(),
func_args=[
include_rois, include_roi_tols],
bundle_def["include"], include_roi_tols],
**parallel_segmentation)

else:
inc_results = abr.check_sls_with_inclusion(
b_sls.get_selected_sls(),
include_rois,
bundle_def["include"],
include_roi_tols)

roi_dists = -np.ones(
(len(b_sls), max_includes),
roi_closest = -np.ones(
(max_includes, len(b_sls)),
dtype=np.int32)
if flip_using_include:
to_flip = np.ones_like(accept_idx, dtype=np.bool8)
for sl_idx, inc_result in enumerate(inc_results):
sl_accepted, sl_dist = inc_result
sl_accepted, sl_closest = inc_result

if sl_accepted:
if len(sl_dist) > 1:
roi_dists[sl_idx, :len(sl_dist)] = [
np.argmin(dist, 0)[0]
for dist in sl_dist]
first_roi_idx = roi_dists[sl_idx, 0]
last_roi_idx = roi_dists[
sl_idx, len(sl_dist) - 1]
if len(sl_closest) > 1:
roi_closest[:len(sl_closest), sl_idx] = sl_closest
# Only accept SLs that, when cut, are meaningful
if (len(sl_dist) < 2) or abs(
first_roi_idx - last_roi_idx) > 1:
if (len(sl_closest) < 2) or abs(
sl_closest[0] - sl_closest[-1]) > 1:
# Flip sl if it is close to second ROI
# before its close to the first ROI
if flip_using_include:
to_flip[sl_idx] =\
first_roi_idx > last_roi_idx
sl_closest[0] > sl_closest[-1]
if to_flip[sl_idx]:
roi_dists[sl_idx, :len(sl_dist)] =\
np.flip(roi_dists[
sl_idx, :len(sl_dist)])
roi_closest[:len(sl_closest), sl_idx] =\
np.flip(sl_closest)
accept_idx[sl_idx] = 1
else:
accept_idx[sl_idx] = 1
Expand All @@ -191,7 +185,7 @@ def include(b_sls, bundle_def, preproc_imap, max_includes,
"backend", "loky") == "loky")):
from joblib.externals.loky import get_reusable_executor
get_reusable_executor().shutdown(wait=True)
b_sls.roi_dists = roi_dists
b_sls.roi_closest = roi_closest.T
if flip_using_include:
b_sls.reorient(to_flip)
b_sls.select(accept_idx, "include")
Expand Down Expand Up @@ -240,13 +234,9 @@ def exclude(b_sls, bundle_def, preproc_imap, **kwargs):
else:
exclude_roi_tols = [
preproc_imap["tol"]**2] * len(bundle_def["exclude"])
exclude_rois = []
for exclude_roi in bundle_def["exclude"]:
exclude_rois.append(np.array(
np.where(exclude_roi.get_fdata())).T)
for sl_idx, sl in enumerate(b_sls.get_selected_sls()):
if abr.check_sl_with_exclusion(
sl, exclude_rois, exclude_roi_tols):
sl, bundle_def["exclude"], exclude_roi_tols):
accept_idx[sl_idx] = 1
b_sls.select(accept_idx, "exclude")

Expand Down Expand Up @@ -332,7 +322,7 @@ def mahalanobis(b_sls, bundle_def, clip_edges, cleaning_params, **kwargs):

def run_bundle_rec_plan(
bundle_dict, tg, mapping, img, reg_template, preproc_imap,
bundle_name, bundle_idx, bundle_to_flip, bundle_roi_dists,
bundle_name, bundle_idx, bundle_to_flip, bundle_roi_closest,
bundle_decisions,
**segmentation_params):
# Warp ROIs
Expand All @@ -344,6 +334,15 @@ def run_bundle_rec_plan(
mapping,
img.affine,
apply_to_recobundles=True))
apply_to_roi_dict(
bundle_def,
lambda roi_img: nib.Nifti1Image(
distance_transform_edt(
np.where(roi_img.get_fdata() == 0, 1, 0)),
roi_img.affine),
dry_run=False,
apply_to_recobundles=False,
apply_to_prob_map=False)
logger.info(f"Time to prep ROIs: {time()-start_time}s")

b_sls = abu.SlsBeingRecognized(
Expand Down Expand Up @@ -404,8 +403,9 @@ def run_bundle_rec_plan(
bundle_decisions[
b_sls.selected_fiber_idxs,
bundle_idx] = 1
if hasattr(b_sls, "roi_dists"):
bundle_roi_dists[
if hasattr(b_sls, "roi_closest"):
bundle_roi_closest[
b_sls.selected_fiber_idxs,
bundle_idx
] = b_sls.roi_dists.copy()
bundle_idx,
:
] = b_sls.roi_closest.copy()
14 changes: 7 additions & 7 deletions AFQ/recognition/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def recognize(
bundle_to_flip = np.zeros(
(n_streamlines, len(bundle_dict)),
dtype=np.bool8)
bundle_roi_dists = -np.ones(
bundle_roi_closest = -np.ones(
(
n_streamlines,
len(bundle_dict),
Expand All @@ -180,7 +180,7 @@ def recognize(
logger.info(f"Finding Streamlines for {bundle_name}")
run_bundle_rec_plan(
bundle_dict, tg, mapping, img, reg_template, preproc_imap,
bundle_name, bundle_idx, bundle_to_flip, bundle_roi_dists,
bundle_name, bundle_idx, bundle_to_flip, bundle_roi_closest,
bundle_decisions,
clip_edges=clip_edges,
parallel_segmentation=parallel_segmentation,
Expand Down Expand Up @@ -233,22 +233,22 @@ def recognize(
# Use a list here, because ArraySequence doesn't support item
# assignment:
select_sl = list(tg.streamlines[select_idx])
roi_dists = bundle_roi_dists[select_idx, bundle_idx, :]
roi_closest = bundle_roi_closest[select_idx, bundle_idx, :]
n_includes = len(bundle_dict.get_b_info(
bundle).get("include", []))
if clip_edges and n_includes > 1:
logger.info("Clipping Streamlines by ROI")
select_sl = abu.cut_sls_by_dist(
select_sl, roi_dists,
select_sl = abu.cut_sls_by_closest(
select_sl, roi_closest,
(0, n_includes - 1), in_place=True)

to_flip = bundle_to_flip[select_idx, bundle_idx]
b_def = dict(bundle_dict.get_b_info(bundle_name))
if "bundlesection" in b_def:
for sb_name, sb_include_cuts in bundle_dict.get_b_info(
bundle)["bundlesection"].items():
bundlesection_select_sl = abu.cut_sls_by_dist(
select_sl, roi_dists,
bundlesection_select_sl = abu.cut_sls_by_closest(
select_sl, roi_closest,
sb_include_cuts, in_place=False)
_add_bundle_to_fiber_group(
sb_name, bundlesection_select_sl, select_idx,
Expand Down
32 changes: 15 additions & 17 deletions AFQ/recognition/roi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import numpy as np
from scipy.spatial.distance import cdist
from scipy.ndimage import binary_dilation
from dipy.core.interpolation import interpolate_scalar_3d


def _interp3d(roi, sl):
return interpolate_scalar_3d(roi.get_fdata(), np.asarray(sl))[0]


def check_sls_with_inclusion(sls, include_rois, include_roi_tols):
Expand All @@ -17,15 +20,16 @@ def check_sl_with_inclusion(sl, include_rois,
Helper function to check that a streamline is close to a list of
inclusion ROIS.
"""
dist = []
closest = np.zeros(len(include_rois), dtype=np.int32)
for ii, roi in enumerate(include_rois):
# Use squared Euclidean distance, because it's faster:
dist.append(cdist(sl, roi, 'sqeuclidean'))
if np.min(dist[-1]) > include_roi_tols[ii]:
dist = _interp3d(roi, sl)
closest[ii] = np.argmin(dist)
if dist[closest[ii]] > include_roi_tols[ii]:
# Too far from one of them:
return False, []

# Apparently you checked all the ROIs and it was close to all of them
return True, dist
return True, closest


def check_sl_with_exclusion(sl, exclude_rois,
Expand All @@ -34,8 +38,9 @@ def check_sl_with_exclusion(sl, exclude_rois,
list of exclusion ROIs.
"""
for ii, roi in enumerate(exclude_rois):
# Use squared Euclidean distance, because it's faster:
if np.min(cdist(sl, roi, 'sqeuclidean')) < exclude_roi_tols[ii]:
# if any part of the streamline is near any exclusion ROI,
# return False
if np.any(_interp3d(roi, sl) <= exclude_roi_tols[ii]):
return False
# Either there are no exclusion ROIs, or you are not close to any:
return True
Expand Down Expand Up @@ -78,17 +83,10 @@ def clean_by_endpoints(streamlines, target, target_idx, tol=0,
flip_sls = np.zeros(len(streamlines))
flip_sls = flip_sls.astype(int)

roi = target.get_fdata()
if tol > 0:
roi = binary_dilation(
roi,
iterations=tol)

for ii, sl in enumerate(streamlines):
this_idx = target_idx
if flip_sls[ii]:
this_idx = (len(sl) - this_idx - 1) % len(sl)
xx, yy, zz = sl[this_idx].astype(int)
accepted_idxs[ii] = roi[xx, yy, zz]
accepted_idxs[ii] = _interp3d(target, [sl[this_idx]])[0] <= tol

return accepted_idxs
Loading

0 comments on commit bcbb266

Please sign in to comment.