Skip to content

Commit

Permalink
Merge pull request #20 from 36000/update_gpu
Browse files Browse the repository at this point in the history
Update gpu code
  • Loading branch information
arokem authored Nov 14, 2024
2 parents 228d77b + 05dc672 commit 36db25c
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 172 deletions.
27 changes: 13 additions & 14 deletions AFQ/recognition/cleaning.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,7 @@ def clean_bundle(tg, n_points=100, clean_rounds=5, distance_threshold=3,
# We'll only do this for clean_rounds
rounds_elapsed = 0
idx_belong = idx
while (rounds_elapsed < clean_rounds) and (np.sum(idx_belong) > min_sl):
# Update by selection:
idx = idx[idx_belong]
fgarray = fgarray[idx_belong]
lengths = lengths[idx_belong]
rounds_elapsed += 1

while rounds_elapsed < clean_rounds:
# This calculates the Mahalanobis for each streamline/node:
m_dist = gaussian_weights(
fgarray, return_mahalnobis=True,
Expand All @@ -150,8 +144,8 @@ def clean_bundle(tg, n_points=100, clean_rounds=5, distance_threshold=3,
f"{length_z}"))

if not (
np.any(m_dist > distance_threshold)
or np.any(length_z > length_threshold)):
np.any(m_dist >= distance_threshold)
or np.any(length_z >= length_threshold)):
break
# Select the fibers that have Mahalanobis smaller than the
# threshold for all their nodes:
Expand All @@ -161,17 +155,22 @@ def clean_bundle(tg, n_points=100, clean_rounds=5, distance_threshold=3,

if np.sum(idx_belong) < min_sl:
# need to sort and return exactly min_sl:
idx_belong = np.argsort(np.sum(
m_dist, axis=-1))[:min_sl].astype(int)
idx = idx[np.argsort(np.sum(
m_dist, axis=-1))[:min_sl].astype(int)]
logger.debug((
f"At rounds elapsed {rounds_elapsed}, "
"minimum streamlines reached"))
break
else:
idx_removed = idx_belong == 0
# Update by selection:
idx = idx[idx_belong]
fgarray = fgarray[idx_belong]
lengths = lengths[idx_belong]
rounds_elapsed += 1
logger.debug((
f"Rounds elapsed: {rounds_elapsed}, "
f"num removed: {np.sum(idx_removed)}"))
logger.debug(f"Removed indicies: {np.where(idx_removed)[0]}")
f"num kept: {len(idx)}"))
logger.debug(f"Kept indicies: {idx}")

# Select based on the variable that was keeping track of things for us:
if hasattr(tg, "streamlines"):
Expand Down
2 changes: 1 addition & 1 deletion AFQ/recognition/recognize.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def recognize(
_add_bundle_to_fiber_group(
sb_name, bundlesection_select_sl, select_idx,
to_flip, return_idx, fiber_groups, img)
_add_bundle_to_meta(sb_name, b_def)
_add_bundle_to_meta(sb_name, b_def, meta)
else:
_add_bundle_to_fiber_group(
bundle, select_sl, select_idx, to_flip,
Expand Down
4 changes: 4 additions & 0 deletions AFQ/tasks/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ def get_segmentation_plan(kwargs):
and not isinstance(kwargs["segmentation_params"], dict):
raise TypeError(
"segmentation_params a dict")
if "cleaning_params" in kwargs:
raise ValueError(
"cleaning_params should be passed inside of"
"segmentation_params")
segmentation_tasks = with_name([
get_scalar_dict,
export_sl_counts,
Expand Down
61 changes: 36 additions & 25 deletions AFQ/tasks/tractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from time import time
import logging

import dipy.data as dpd

import pimms
import multiprocessing

Expand All @@ -12,19 +14,16 @@
import AFQ.tractography.tractography as aft
from AFQ.tasks.utils import get_default_args
from AFQ.definitions.image import ScalarImage
from AFQ.tractography.utils import gen_seeds
from AFQ.tractography.utils import gen_seeds, get_percentile_threshold

from trx.trx_file_memmap import TrxFile
from trx.trx_file_memmap import concatenate as trx_concatenate

try:
import ray
has_ray = True
except ModuleNotFoundError:
has_ray = False
try:
from trx.trx_file_memmap import TrxFile
from trx.trx_file_memmap import concatenate as trx_concatenate
has_trx = True
except ModuleNotFoundError:
has_trx = False

try:
from AFQ.tractography.gputractography import gpu_track
Expand Down Expand Up @@ -70,7 +69,13 @@ def export_seed_mask(data_imap, tracking_params):
tractography seed mask
"""
seed_mask = tracking_params['seed_mask']
seed_mask_desc = dict(source=tracking_params['seed_mask'])
seed_threshold = tracking_params['seed_threshold']
if tracking_params['thresholds_as_percentages']:
seed_threshold = get_percentile_threshold(
seed_mask, seed_threshold)
seed_mask_desc = dict(
source=tracking_params['seed_mask'],
threshold=seed_threshold)
return nib.Nifti1Image(seed_mask, data_imap["dwi_affine"]), \
seed_mask_desc

Expand All @@ -83,7 +88,13 @@ def export_stop_mask(data_imap, tracking_params):
tractography stop mask
"""
stop_mask = tracking_params['stop_mask']
stop_mask_desc = dict(source=tracking_params['stop_mask'])
stop_threshold = tracking_params['stop_threshold']
if tracking_params['thresholds_as_percentages']:
stop_threshold = get_percentile_threshold(
stop_mask, stop_threshold)
stop_mask_desc = dict(
source=tracking_params['stop_mask'],
stop_threshold=stop_threshold)
return nib.Nifti1Image(stop_mask, data_imap["dwi_affine"]), \
stop_mask_desc

Expand Down Expand Up @@ -290,23 +301,38 @@ def gpu_tractography(data_imap, tracking_params, seed, stop,
Number of GPUs to use in tractography. If non-0,
this algorithm is used for tractography,
https://github.com/dipy/GPUStreamlines
PTT, Prob can be used with any SHM model.
Bootstrapped can be done with CSA/OPDT.
Default: 0
chunk_size : int, optional
Chunk size for GPU tracking.
Default: 100000
"""
start_time = time()
if tracking_params["directions"] == "boot":
data = data_imap["data"]
else:
data = nib.load(
data_imap[tracking_params["odf_model"].lower() + "_params"]).get_fdata()

sphere = tracking_params["sphere"]
if sphere is None:
sphere = dpd.default_sphere

sft = gpu_track(
data_imap["data"], data_imap["gtab"],
data, data_imap["gtab"],
nib.load(seed), nib.load(stop),
tracking_params["odf_model"],
sphere,
tracking_params["directions"],
tracking_params["seed_threshold"],
tracking_params["stop_threshold"],
tracking_params["thresholds_as_percentages"],
tracking_params["max_angle"], tracking_params["step_size"],
tracking_params["n_seeds"],
tracking_params["random_seeds"],
tracking_params["rng_seed"],
tracking_params["trx"],
tractography_ngpus,
chunk_size)

Expand Down Expand Up @@ -403,18 +429,3 @@ def get_tractography_plan(kwargs):
seed_mask.get_image_getter("tractography")))

return pimms.plan(**tractography_tasks)


def _gen_seeds(n_seeds, params_file, seed_mask=None, seed_threshold=0,
thresholds_as_percentages=False,
random_seeds=False, rng_seed=None):
if isinstance(params_file, str):
params_img = nib.load(params_file)
else:
params_img = params_file

affine = params_img.affine

return gen_seeds(seed_mask, seed_threshold, n_seeds,
thresholds_as_percentages,
random_seeds, rng_seed, affine)
Loading

0 comments on commit 36db25c

Please sign in to comment.