Skip to content

Commit

Permalink
BF/twaks
Browse files Browse the repository at this point in the history
  • Loading branch information
36000 committed Dec 5, 2024
1 parent e98128b commit 07d7f3e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
10 changes: 5 additions & 5 deletions AFQ/tasks/tractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def custom_tractography(import_tract=None):

@pimms.calc("streamlines")
@as_file('_tractography', include_track=True)
def gpu_tractography(data_imap, tracking_params, seed, stop,
def gpu_tractography(data_imap, fodf, tracking_params, seed, stop,
tractography_ngpus=0, chunk_size=100000):
"""
full path to the complete, unsegmented tractography file
Expand All @@ -313,8 +313,7 @@ def gpu_tractography(data_imap, tracking_params, seed, stop,
if tracking_params["directions"] == "boot":
data = data_imap["data"]
else:
data = nib.load(
data_imap[tracking_params["odf_model"].lower() + "_params"]).get_fdata()
data = fodf.get_fdata()

sphere = tracking_params["sphere"]
if sphere is None:
Expand Down Expand Up @@ -392,8 +391,9 @@ def get_tractography_plan(kwargs):
default_tracking_params[k] = kwargs["tracking_params"][k]

kwargs["tracking_params"] = default_tracking_params
kwargs["tracking_params"]["odf_model"] =\
kwargs["tracking_params"]["odf_model"].upper()
if isinstance(kwargs["tracking_params"]["odf_model"], str):
kwargs["tracking_params"]["odf_model"] =\
kwargs["tracking_params"]["odf_model"].upper()
if kwargs["tracking_params"]["seed_mask"] is None:
kwargs["tracking_params"]["seed_mask"] = ScalarImage(
kwargs["best_scalar"])
Expand Down
6 changes: 4 additions & 2 deletions AFQ/tractography/tractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def track(params_file, directions="prob", max_angle=30., sphere=None,
seed_mask=None, seed_threshold=0, thresholds_as_percentages=False,
n_seeds=1, random_seeds=False, rng_seed=None, stop_mask=None,
stop_threshold=0, step_size=0.5, minlen=50, maxlen=250,
odf_model="CSD", basis_type="descoteaux07", tracker="local",
odf_model="CSD", basis_type="descoteaux07", tracker="local",
trx=False):
"""
Tractography
Expand Down Expand Up @@ -149,13 +149,15 @@ def track(params_file, directions="prob", max_angle=30., sphere=None,
else:
raise ValueError(f"Unrecognized direction '{directions}'.")

logger.debug(f"Using basis type: {basis_type}")

if odf_model == "DTI" or odf_model == "DKI":
evals = model_params[..., :3]
evecs = model_params[..., 3:12].reshape(params_img.shape[:3] + (3, 3))
odf = tensor_odf(evals, evecs, sphere)
dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere)
else:
dg = dg.from_shcoeff(model_params, max_angle=max_angle, sphere=sphere,
dg = dg.from_shcoeff(model_params, max_angle=max_angle, sphere=sphere,
basis_type=basis_type, legacy=False)

if tracker == "local":
Expand Down

0 comments on commit 07d7f3e

Please sign in to comment.