From 63b83c03399f37f6d167d4e29370c90e6034f7e1 Mon Sep 17 00:00:00 2001 From: 36000 Date: Sat, 27 Jan 2024 11:37:51 -0800 Subject: [PATCH] Seeding on the white matter-gray mattter interface --- AFQ/definitions/image.py | 120 ++++++++++++++++++++++++++++++++++++-- AFQ/tasks/tractography.py | 6 +- AFQ/utils/path.py | 8 +++ AFQ/viz/utils.py | 8 +-- 4 files changed, 130 insertions(+), 12 deletions(-) diff --git a/AFQ/definitions/image.py b/AFQ/definitions/image.py index 95e1f1019..f372c8c33 100644 --- a/AFQ/definitions/image.py +++ b/AFQ/definitions/image.py @@ -4,16 +4,17 @@ import nibabel as nib from dipy.segment.mask import median_otsu +from dipy.segment.tissue import TissueClassifierHMRF from dipy.align import resample -import AFQ.utils.volume as auv from AFQ.definitions.utils import Definition, find_file, name_from_path from skimage.morphology import convex_hull_image, binary_opening -from scipy.linalg import blas +from scipy.ndimage import binary_dilation __all__ = [ - "ImageFile", "FullImage", "RoiImage", "B0Image", "LabelledImageFile", + "ImageFile", "FullImage", "RoiImage", "HMRFImage", "B0Image", + "B0ThreshImage", "LabelledImageFile", "ThresholdedImageFile", "ScalarImage", "ThresholdedScalarImage", "TemplateImage", "GQImage"] @@ -391,6 +392,117 @@ def image_getter_helper(gq_aso): data_imap["gq_aso"]) +class HMRFImage(ImageDefinition): + """ + Use the Hidden Markov Random Field segmentation to generate + a segmentation on white matter, gray matter, and CSF, + from the CSD-derived anisotropic power map. Then, + use this to make a mask of the WM-GM interface. + + Parameters + ---------- + niter : int, optional + Number of iterations to run the HMRF segmentation. + Default: 5 + beta : float, optional + Smoothness factor for the HMRF segmentation. + Default: 0.5 + tp : string, optional + Tissue property derived from diffusion data + which looks similar to T1. Both + "csd_pmap" and "gq_pmap" have been tested. + Default: "csd_pmap" + + Examples + -------- + seed_def = HMRFImage() + api.GroupAFQ(tracking_params={ + "seed_image": seed_def}) + """ + + def __init__(self, beta=0.5, niter=5, tp="csd_pmap"): + self.beta = beta + self.niter = niter + self.tp = tp + + def find_path(self, bids_layout, from_path, subject, session): + pass + + def get_name(self): + return "hmrf" + + def get_image_getter(self, task_name): + if task_name == "data": + raise ValueError("HMRFImage cannot be used in this context") + + def image_getter_helper(data_imap): + brain_mask_img = nib.load(data_imap["brain_mask"]) + hmrf = TissueClassifierHMRF() + pmap_img = nib.load(data_imap[self.tp]) + pmap = pmap_img.get_fdata() + pmap = brain_mask_img.get_fdata() * pmap + logger.info("Generating WM-GM interface mask") + pmap[pmap != 0] = pmap[pmap != 0] - pmap[pmap != 0].min() + _, _, PVE = hmrf.classify( + pmap, 3, self.beta, max_iter=self.niter) + wmgmi = np.logical_and( + PVE[..., 1] > 0.1, + binary_dilation(PVE[..., 2] > 0.1)) + return nib.Nifti1Image( + wmgmi.astype(np.float32), + pmap_img.affine), dict( + source=data_imap[self.tp], + beta=self.beta, + niter=self.niter) + return image_getter_helper + + +class B0ThreshImage(ImageDefinition): + """ + Define an image using thresholding and convex hull on the + b0 data. Note that, for now, this will remove the skull but + keep the eyes. + + Parameters + ---------- + b0_noise_thresh: int, optional + B0 noise threshold for making brain mask from B0. + Default: 1000 + + Examples + -------- + brain_image_definition = B0ThreshImage() + api.GroupAFQ(brain_image_definition=brain_image_definition) + """ + + def __init__(self, b0_noise_thresh=1000): + self.b0_noise_thresh = b0_noise_thresh + + def find_path(self, bids_layout, from_path, subject, session): + pass + + def get_name(self): + return "b0" + + def get_image_getter(self, task_name): + def image_getter_helper(b0): + b0_img = nib.load(b0) + b0_dat = b0_img.get_fdata() + b0_dat = b0_dat[b0_dat > 1000].flatten() + attempt_at_mask = convex_hull_image( + binary_opening( + b0_img.get_fdata() > np.percentile(b0_dat, 10))) + return nib.Nifti1Image( + attempt_at_mask.astype(np.float32), + b0_img.affine), dict( + source=b0, + b0_noise_thresh=self.b0_noise_thresh) + if task_name == "data": + return image_getter_helper + else: + return lambda data_imap: image_getter_helper(data_imap["b0"]) + + class B0Image(ImageDefinition): """ Define an image using b0 and dipy's median_otsu. @@ -414,7 +526,7 @@ def find_path(self, bids_layout, from_path, subject, session): pass def get_name(self): - return "b0" + return "otsu" def get_image_getter(self, task_name): def image_getter_helper(b0): diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index 4f490cbd6..c93136ff5 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -10,7 +10,7 @@ from AFQ.definitions.utils import Definition import AFQ.tractography.tractography as aft from AFQ.tasks.utils import get_default_args -from AFQ.definitions.image import ScalarImage +from AFQ.definitions.image import ScalarImage, HMRFImage try: from trx.trx_file_memmap import TrxFile @@ -290,9 +290,7 @@ def get_tractography_plan(kwargs): kwargs["tracking_params"]["odf_model"] =\ kwargs["tracking_params"]["odf_model"].upper() if kwargs["tracking_params"]["seed_mask"] is None: - kwargs["tracking_params"]["seed_mask"] = ScalarImage( - kwargs["best_scalar"]) - kwargs["tracking_params"]["seed_threshold"] = 0.2 + kwargs["tracking_params"]["seed_mask"] = HMRFImage() if kwargs["tracking_params"]["stop_mask"] is None: kwargs["tracking_params"]["stop_mask"] = ScalarImage( kwargs["best_scalar"]) diff --git a/AFQ/utils/path.py b/AFQ/utils/path.py index fbfd7f649..b4152dfe6 100644 --- a/AFQ/utils/path.py +++ b/AFQ/utils/path.py @@ -1,6 +1,10 @@ import os.path as op import os import json +import logging + + +logger = logging.getLogger('AFQ') def write_json(fname, data): @@ -62,6 +66,10 @@ def apply_cmd_to_afq_derivs( "dependent_on must be one of " "None, 'track', 'recog', 'prof'.")) + if not op.exists(derivs_dir): + logger.warning(f"Nothing to {cmd} in {derivs_dir}") + return + for filename in os.listdir(derivs_dir): full_path = os.path.join(derivs_dir, filename) if os.path.isfile(full_path) or os.path.islink(full_path): diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index dd9750e93..445fd6397 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -56,9 +56,9 @@ "Right Cingulum Cingulate": tableau_20[5], "CCMid": tableau_20[5], "Forceps Minor": tableau_20[8], "CC_ForcepsMinor": tableau_20[8], "Forceps Major": tableau_20[9], "CC_ForcepsMajor": tableau_20[9], - "Left Inferior Fronto-Occipital": tableau_20[10], + "Left Inferior Fronto-occipital": tableau_20[10], "IFOF_L": tableau_20[10], - "Right Inferior Fronto-Occipital": tableau_20[11], + "Right Inferior Fronto-occipital": tableau_20[11], "IFOF_R": tableau_20[11], "Left Inferior Longitudinal": tableau_20[12], "F_L": tableau_20[12], "Right Inferior Longitudinal": tableau_20[13], "F_R": tableau_20[13], @@ -88,8 +88,8 @@ "MCP": (3, 1), "CCMid": (3, 3), "Forceps Minor": (4, 2), "Forceps Major": (0, 2), "CC_ForcepsMinor": (4, 2), "CC_ForcepsMajor": (0, 2), - "Left Inferior Fronto-Occipital": (4, 1), - "Right Inferior Fronto-Occipital": (4, 3), + "Left Inferior Fronto-occipital": (4, 1), + "Right Inferior Fronto-occipital": (4, 3), "IFOF_L": (4, 1), "IFOF_R": (4, 3), "Left Inferior Longitudinal": (3, 0), "Right Inferior Longitudinal": (3, 4),