From 1c664e89bfbca90784fd09f06dad37a7b9aa05b5 Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 5 Dec 2024 12:48:23 -0800 Subject: [PATCH] BF and allow legacy change --- AFQ/tasks/tractography.py | 4 +--- AFQ/tasks/utils.py | 3 ++- AFQ/tractography/tractography.py | 11 +++++++---- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index f318d43d..8d987817 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -260,11 +260,9 @@ def fiber_odf(data_imap, tracking_params): odf_model = tracking_params["odf_model"] if isinstance(odf_model, str): params_file = data_imap[f"{odf_model.lower()}_params"] - elif isinstance(odf_model, nib.Nifti1Image): - params_file = odf_model else: raise TypeError(( - "odf_model must be a str, nibabel image, or Definition")) + "odf_model must be a string or Definition")) return params_file diff --git a/AFQ/tasks/utils.py b/AFQ/tasks/utils.py index a0db74d4..8728c9c8 100644 --- a/AFQ/tasks/utils.py +++ b/AFQ/tasks/utils.py @@ -1,6 +1,5 @@ from AFQ.utils.path import drop_extension import os.path as op -import inspect __all__ = ["get_fname", "with_name", "get_base_fname"] @@ -27,6 +26,8 @@ def get_fname(base_fname, suffix, fname = base_fname if tracking_params is not None and 'odf_model' in tracking_params: odf_model = tracking_params['odf_model'] + if not isinstance(odf_model, str): + odf_model = odf_model.get_name() directions = tracking_params['directions'] fname = fname + ( f'_coordsys-RASMM_trkmethod-{directions+odf_model}' diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index e2ddafd3..9c5bb03a 100644 --- a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -25,8 +25,8 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, seed_mask=None, seed_threshold=0, thresholds_as_percentages=False, n_seeds=1, random_seeds=False, rng_seed=None, stop_mask=None, stop_threshold=0, step_size=0.5, minlen=50, maxlen=250, - odf_model="CSD", basis_type="descoteaux07", tracker="local", - trx=False): + odf_model="CSD", basis_type="descoteaux07", legacy=True, + tracker="local", trx=False): """ Tractography @@ -97,6 +97,9 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, basis_type : str, optional The spherical harmonic basis type used to represent the coefficients. One of {"descoteaux07", "tournier07"}. Deafult: "descoteaux07" + legacy : bool, optional + Whether to use the legacy implementation of the direction getter. + See Dipy documentation for more details. Default: True tracker : str, optional Which strategy to use in tracking. This can be the standard local tracking ("local") or Particle Filtering Tracking ([Girard2014]_). @@ -155,10 +158,10 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, evals = model_params[..., :3] evecs = model_params[..., 3:12].reshape(params_img.shape[:3] + (3, 3)) odf = tensor_odf(evals, evecs, sphere) - dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere) + dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere, legacy=legacy) else: dg = dg.from_shcoeff(model_params, max_angle=max_angle, sphere=sphere, - basis_type=basis_type, legacy=False) + basis_type=basis_type, legacy=legacy) if tracker == "local": if stop_mask is None: