Skip to content

Commit

Permalink
stitch in custom FODFs
Browse files Browse the repository at this point in the history
  • Loading branch information
36000 committed Dec 5, 2024
1 parent 17816fb commit e98128b
Showing 1 changed file with 27 additions and 21 deletions.
48 changes: 27 additions & 21 deletions AFQ/tasks/tractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def export_stop_mask_pft(pve_wm, pve_gm, pve_csf):

@pimms.calc("streamlines")
@as_file('_tractography', include_track=True)
def streamlines(data_imap, seed, stop,
def streamlines(data_imap, seed, stop, fodf,
tracking_params):
"""
full path to the complete, unsegmented tractography file
Expand All @@ -128,22 +128,6 @@ def streamlines(data_imap, seed, stop,
"""
this_tracking_params = tracking_params.copy()

# get odf_model
odf_model = this_tracking_params["odf_model"]
if odf_model == "DTI":
params_file = data_imap["dti_params"]
elif odf_model == "CSD":
params_file = data_imap["csd_params"]
elif odf_model == "DKI":
params_file = data_imap["dki_params"]
elif odf_model == "GQ":
params_file = data_imap["gq_params"]
elif odf_model == "RUMBA":
params_file = data_imap["rumba_params"]
else:
raise TypeError((
f"The ODF model you gave ({odf_model}) was not recognized"))

# get masks
this_tracking_params['seed_mask'] = nib.load(seed).get_fdata()
if isinstance(stop, str):
Expand Down Expand Up @@ -233,7 +217,7 @@ def delete_lazyt(self, id):
tracking_params_list[i]['n_seeds'] = seed_chunks[i]

# create lazyt inside each actor
tasks = [ray_actor.create_lazyt.remote(object_id, params_file,
tasks = [ray_actor.create_lazyt.remote(object_id, fodf,
**tracking_params_list[i]) for i, ray_actor in
enumerate(actors)]
ray.get(tasks)
Expand All @@ -250,7 +234,7 @@ def delete_lazyt(self, id):

sft = trx_concatenate(sfts)
else:
lazyt = aft.track(params_file, **this_tracking_params)
lazyt = aft.track(fodf, **this_tracking_params)
sft = TrxFile.from_lazy_tractogram(
lazyt,
seed,
Expand All @@ -259,7 +243,7 @@ def delete_lazyt(self, id):

else:
start_time = time()
sft = aft.track(params_file, **this_tracking_params)
sft = aft.track(fodf, **this_tracking_params)
sft.to_vox()
n_streamlines = len(sft.streamlines)

Expand All @@ -268,6 +252,23 @@ def delete_lazyt(self, id):
n_streamlines, seed, stop)


@pimms.calc("fodf")
def fiber_odf(data_imap, tracking_params):
"""
Nifti Image containing the fiber orientation distribution function
"""
odf_model = tracking_params["odf_model"]
if isinstance(odf_model, str):
params_file = data_imap[f"{odf_model.lower()}_params"]
elif isinstance(odf_model, nib.Nifti1Image):
params_file = odf_model
else:
raise TypeError((
"odf_model must be a str, nibabel image, or Definition"))

return params_file


@pimms.calc("streamlines")
def custom_tractography(import_tract=None):
"""
Expand Down Expand Up @@ -348,7 +349,7 @@ def get_tractography_plan(kwargs):
"tracking_params a dict")

tractography_tasks = with_name([
export_seed_mask, export_stop_mask, streamlines])
export_seed_mask, export_stop_mask, streamlines, fiber_odf])

# use GPU accelerated tractography if asked for
if "tractography_ngpus" in kwargs and kwargs["tractography_ngpus"] != 0:
Expand Down Expand Up @@ -410,6 +411,7 @@ def get_tractography_plan(kwargs):

stop_mask = kwargs["tracking_params"]['stop_mask']
seed_mask = kwargs["tracking_params"]['seed_mask']
odf_model = kwargs["tracking_params"]['odf_model']

if kwargs["tracking_params"]["tracker"] == "pft":
probseg_funcs = stop_mask.get_image_getter("tractography")
Expand All @@ -428,4 +430,8 @@ def get_tractography_plan(kwargs):
as_file('_desc-seed_mask.nii.gz', include_track=True)(
seed_mask.get_image_getter("tractography")))

if isinstance(odf_model, Definition):
tractography_tasks["fiber_odf_res"] = pimms.calc("fodf")(
odf_model.get_image_getter("tractography"))

return pimms.plan(**tractography_tasks)

0 comments on commit e98128b

Please sign in to comment.