From 485e17182e582242ade0177b9262a555926a77ca Mon Sep 17 00:00:00 2001 From: 36000 Date: Thu, 19 Dec 2024 17:42:51 -0800 Subject: [PATCH] more tweaks --- AFQ/api/bundle_dict.py | 15 +++++++++++++-- AFQ/tasks/data.py | 4 ++-- AFQ/tasks/mapping.py | 4 ++-- AFQ/tasks/segmentation.py | 17 ++++++++++++----- AFQ/tasks/tractography.py | 8 ++++---- AFQ/tasks/utils.py | 5 +---- AFQ/tasks/viz.py | 2 +- AFQ/utils/path.py | 3 +++ 8 files changed, 38 insertions(+), 20 deletions(-) diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index 38c7fe4a..25064476 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -1050,14 +1050,25 @@ def transform_rois(self, bundle_name, mapping, new_affine, if base_fname is not None: fnames = [] for roi_type, rois in transformed_rois.items(): + if roi_type == "prob_map": + suffix = "probseg" + else: + suffix = "mask" + roi_type_name = roi_type.lower().replace( + " ", "").replace( + "_", "").replace( + "-", "") if not isinstance(rois, list): rois = [rois] for ii, roi in enumerate(rois): + suffix = f"{str_to_desc(bundle_name)}{roi_type_name}" + if roi_type in ["include", "exclude"]: + suffix = f"{suffix}{ii}" fname = get_fname( base_fname, "_space-subject_desc-" - f"{str_to_desc(bundle_name)}{roi_type}{ii}" - "_mask.nii.gz", + f"{suffix}" + f"_{suffix}.nii.gz", "ROIs") nib.save( nib.Nifti1Image( diff --git a/AFQ/tasks/data.py b/AFQ/tasks/data.py index b184b903..ff462464 100644 --- a/AFQ/tasks/data.py +++ b/AFQ/tasks/data.py @@ -93,7 +93,7 @@ def get_data_gtab(dwi_data_file, bval_file, bvec_file, min_bval=None, @pimms.calc("b0") -@as_file('_desc-b0_dwimap.nii.gz') +@as_file('_b0ref.nii.gz') def b0(dwi, gtab): """ full path to a nifti file containing the mean b0 @@ -105,7 +105,7 @@ def b0(dwi, gtab): @pimms.calc("masked_b0") -@as_file('_desc-maskedb0_dwimap.nii.gz') +@as_file('_desc-masked_b0ref.nii.gz') def b0_mask(b0, brain_mask): """ full path to a nifti file containing the diff --git a/AFQ/tasks/mapping.py b/AFQ/tasks/mapping.py index 726ab2af..2660c989 100644 --- a/AFQ/tasks/mapping.py +++ b/AFQ/tasks/mapping.py @@ -28,7 +28,7 @@ def export_registered_b0(base_fname, data_imap, mapping): """ warped_b0_fname = get_fname( base_fname, - f'_space-{data_imap["tmpl_name"]}_desc-b0_dwimap.nii.gz') + f'_space-{data_imap["tmpl_name"]}_b0ref.nii.gz') if not op.exists(warped_b0_fname): mean_b0 = nib.load(data_imap["b0"]).get_fdata() warped_b0 = mapping.transform(mean_b0) @@ -40,7 +40,7 @@ def export_registered_b0(base_fname, data_imap, mapping): dependent="dwi") meta_fname = get_fname( base_fname, - f'_space-{data_imap["tmpl_name"]}_desc-b0_dwimap.json') + f'_space-{data_imap["tmpl_name"]}_b0ref.json') write_json(meta_fname, meta) return warped_b0_fname diff --git a/AFQ/tasks/segmentation.py b/AFQ/tasks/segmentation.py index 2b941794..cca2c1f9 100644 --- a/AFQ/tasks/segmentation.py +++ b/AFQ/tasks/segmentation.py @@ -39,7 +39,7 @@ @pimms.calc("bundles") -@as_file('desc-bundles_tractography', +@as_file('_desc-bundles_tractography', include_track=True, include_seg=True) def segment(data_imap, mapping_imap, @@ -69,18 +69,18 @@ def segment(data_imap, mapping_imap, tg = trx.to_sft() elif streamlines.endswith(".tck.gz"): # uncompress tck.gz to a temporary tck: - temp_tck = op.join(mkdtemp(), op.split(streamlines.replace(".gz", ""))[1]) + temp_tck = op.join(mkdtemp(), op.split( + streamlines.replace(".gz", ""))[1]) logger.info(f"Temporary tck file created at: {temp_tck}") with gzip.open(streamlines, 'rb') as f_in: with open(temp_tck, 'wb') as f_out: shutil.copyfileobj(f_in, f_out) # initialize stateful tractogram from tck file: tg = load_tractogram( - temp_tck, data_imap["dwi"], Space.VOX, + temp_tck, data_imap["dwi"], Space.VOX, bbox_valid_check=False) is_trx = False - indices_to_remove, _ = tg.remove_invalid_streamlines() if len(indices_to_remove) > 0: logger.warning(f"{len(indices_to_remove)} invalid streamlines removed") @@ -335,7 +335,14 @@ def _median_weight(bundle): this_prof_weights = _median_weight else: this_prof_weights = profile_weights - this_prof_weights[np.isnan(this_prof_weights)] = 0 + if np.any(np.isnan(this_prof_weights)): # fit failed + logger.warning(( + f"Even weighting used for " + f"bundle {bundle_name}, scalar {scalar} " + f"in profiling due inability to estimate weights. " + "This is often caused by low streamline count or " + "low variance in the scalar data.")) + this_prof_weights = np.ones_like(this_prof_weights) this_profile[ii] = afq_profile( scalar_data, this_sl, diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index 89ee160b..39248288 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -92,8 +92,8 @@ def export_seed_mask_thresholded(data_imap, seed, tracking_params): tractography seed mask thresholded """ thresh = tracking_params['seed_threshold'] - threshed_data = seed.get_fdata() > thresh - seed_mask_desc = dict(source=tracking_params['seed_mask'], thresh=thresh) + threshed_data = nib.load(seed).get_fdata() > thresh + seed_mask_desc = dict(source=seed, thresh=thresh) return nib.Nifti1Image( threshed_data.astype(np.float32), data_imap["dwi_affine"]), seed_mask_desc @@ -130,8 +130,8 @@ def export_stop_mask_thresholded(data_imap, stop, tracking_params): tractography stop mask thresholded """ thresh = tracking_params['stop_threshold'] - threshed_data = stop.get_fdata() > thresh - stop_mask_desc = dict(source=tracking_params['stop_mask'], thresh=thresh) + threshed_data = nib.load(stop).get_fdata() > thresh + stop_mask_desc = dict(source=stop, thresh=thresh) return nib.Nifti1Image( threshed_data.astype(np.float32), data_imap["dwi_affine"]), stop_mask_desc diff --git a/AFQ/tasks/utils.py b/AFQ/tasks/utils.py index 5c87ef8f..73e45e58 100644 --- a/AFQ/tasks/utils.py +++ b/AFQ/tasks/utils.py @@ -46,10 +46,7 @@ def get_fname(base_fname, suffix, subfolder=None): if folders[-1] == "dwi": if len(folders) > 1 and "sub-" in folders[-2] or \ "ses-" in folders[-2]: - if "sub-" in folders[-2] or len(folders) == 2: - return op.join(*folders[:-2], base_fname + suffix) - else: - return op.join(*folders[:-3], base_fname + suffix) + return op.join(*folders[:-2], folders[-2] + suffix) else: return op.join(*folders[:-1], base_fname + suffix) else: diff --git a/AFQ/tasks/viz.py b/AFQ/tasks/viz.py index e208fda5..1c8f2623 100644 --- a/AFQ/tasks/viz.py +++ b/AFQ/tasks/viz.py @@ -336,7 +336,7 @@ def plot_tract_profiles(base_fname, output_dir, scalars, segmentation_imap): this_scalar = scalar if isinstance(scalar, str) else scalar.get_name() fname = get_fname( base_fname, - f'_param-{str_to_desc(this_scalar)}_desc-vizprofile_dwimap.png', + f'_param-{str_to_desc(this_scalar)}_desc-vizprofile_tractography.png', 'tract_profile_plots') visualize_tract_profiles( diff --git a/AFQ/utils/path.py b/AFQ/utils/path.py index baf003d3..fdc439ca 100644 --- a/AFQ/utils/path.py +++ b/AFQ/utils/path.py @@ -98,3 +98,6 @@ def apply_cmd_to_afq_derivs( if filename == "viz_core_bundles" and \ "prof" in dependent_on_list: os.system(f"{cmd} -r {full_path} {suffix}") + if filename == "tract_profile_plots" and \ + "prof" in dependent_on_list: + os.system(f"{cmd} -r {full_path} {suffix}")