Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Seeding on the white matter-gray mattter interface #1096

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 116 additions & 4 deletions AFQ/definitions/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
import nibabel as nib

from dipy.segment.mask import median_otsu
from dipy.segment.tissue import TissueClassifierHMRF
from dipy.align import resample

import AFQ.utils.volume as auv
from AFQ.definitions.utils import Definition, find_file, name_from_path

from skimage.morphology import convex_hull_image, binary_opening
from scipy.linalg import blas
from scipy.ndimage import binary_dilation

__all__ = [
"ImageFile", "FullImage", "RoiImage", "B0Image", "LabelledImageFile",
"ImageFile", "FullImage", "RoiImage", "HMRFImage", "B0Image",
"B0ThreshImage", "LabelledImageFile",
"ThresholdedImageFile", "ScalarImage", "ThresholdedScalarImage",
"TemplateImage", "GQImage"]

Expand Down Expand Up @@ -391,6 +392,117 @@ def image_getter_helper(gq_aso):
data_imap["gq_aso"])


class HMRFImage(ImageDefinition):
"""
Use the Hidden Markov Random Field segmentation to generate
a segmentation on white matter, gray matter, and CSF,
from the CSD-derived anisotropic power map. Then,
use this to make a mask of the WM-GM interface.

Parameters
----------
niter : int, optional
Number of iterations to run the HMRF segmentation.
Default: 5
beta : float, optional
Smoothness factor for the HMRF segmentation.
Default: 0.5
tp : string, optional
Tissue property derived from diffusion data
which looks similar to T1. Both
"csd_pmap" and "gq_pmap" have been tested.
Default: "csd_pmap"

Examples
--------
seed_def = HMRFImage()
api.GroupAFQ(tracking_params={
"seed_image": seed_def})
"""

def __init__(self, beta=0.5, niter=5, tp="csd_pmap"):
self.beta = beta
self.niter = niter
self.tp = tp

def find_path(self, bids_layout, from_path, subject, session):
pass

def get_name(self):
return "hmrf"

def get_image_getter(self, task_name):
if task_name == "data":
raise ValueError("HMRFImage cannot be used in this context")

def image_getter_helper(data_imap):
brain_mask_img = nib.load(data_imap["brain_mask"])
hmrf = TissueClassifierHMRF()
pmap_img = nib.load(data_imap[self.tp])
pmap = pmap_img.get_fdata()
pmap = brain_mask_img.get_fdata() * pmap
logger.info("Generating WM-GM interface mask")
pmap[pmap != 0] = pmap[pmap != 0] - pmap[pmap != 0].min()
_, _, PVE = hmrf.classify(
pmap, 3, self.beta, max_iter=self.niter)
wmgmi = np.logical_and(
PVE[..., 1] > 0.1,
binary_dilation(PVE[..., 2] > 0.1))
return nib.Nifti1Image(
wmgmi.astype(np.float32),
pmap_img.affine), dict(
source=data_imap[self.tp],
beta=self.beta,
niter=self.niter)
return image_getter_helper


class B0ThreshImage(ImageDefinition):
"""
Define an image using thresholding and convex hull on the
b0 data. Note that, for now, this will remove the skull but
keep the eyes.

Parameters
----------
b0_noise_thresh: int, optional
B0 noise threshold for making brain mask from B0.
Default: 1000

Examples
--------
brain_image_definition = B0ThreshImage()
api.GroupAFQ(brain_image_definition=brain_image_definition)
"""

def __init__(self, b0_noise_thresh=1000):
self.b0_noise_thresh = b0_noise_thresh

def find_path(self, bids_layout, from_path, subject, session):
pass

def get_name(self):
return "b0"

def get_image_getter(self, task_name):
def image_getter_helper(b0):
b0_img = nib.load(b0)
b0_dat = b0_img.get_fdata()
b0_dat = b0_dat[b0_dat > 1000].flatten()
attempt_at_mask = convex_hull_image(
binary_opening(
b0_img.get_fdata() > np.percentile(b0_dat, 10)))
return nib.Nifti1Image(
attempt_at_mask.astype(np.float32),
b0_img.affine), dict(
source=b0,
b0_noise_thresh=self.b0_noise_thresh)
if task_name == "data":
return image_getter_helper
else:
return lambda data_imap: image_getter_helper(data_imap["b0"])


class B0Image(ImageDefinition):
"""
Define an image using b0 and dipy's median_otsu.
Expand All @@ -414,7 +526,7 @@ def find_path(self, bids_layout, from_path, subject, session):
pass

def get_name(self):
return "b0"
return "otsu"

def get_image_getter(self, task_name):
def image_getter_helper(b0):
Expand Down
6 changes: 2 additions & 4 deletions AFQ/tasks/tractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from AFQ.definitions.utils import Definition
import AFQ.tractography.tractography as aft
from AFQ.tasks.utils import get_default_args
from AFQ.definitions.image import ScalarImage
from AFQ.definitions.image import ScalarImage, HMRFImage

try:
from trx.trx_file_memmap import TrxFile
Expand Down Expand Up @@ -290,9 +290,7 @@ def get_tractography_plan(kwargs):
kwargs["tracking_params"]["odf_model"] =\
kwargs["tracking_params"]["odf_model"].upper()
if kwargs["tracking_params"]["seed_mask"] is None:
kwargs["tracking_params"]["seed_mask"] = ScalarImage(
kwargs["best_scalar"])
kwargs["tracking_params"]["seed_threshold"] = 0.2
kwargs["tracking_params"]["seed_mask"] = HMRFImage()
if kwargs["tracking_params"]["stop_mask"] is None:
kwargs["tracking_params"]["stop_mask"] = ScalarImage(
kwargs["best_scalar"])
Expand Down
8 changes: 8 additions & 0 deletions AFQ/utils/path.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import os.path as op
import os
import json
import logging


logger = logging.getLogger('AFQ')


def write_json(fname, data):
Expand Down Expand Up @@ -62,6 +66,10 @@ def apply_cmd_to_afq_derivs(
"dependent_on must be one of "
"None, 'track', 'recog', 'prof'."))

if not op.exists(derivs_dir):
logger.warning(f"Nothing to {cmd} in {derivs_dir}")
return

for filename in os.listdir(derivs_dir):
full_path = os.path.join(derivs_dir, filename)
if os.path.isfile(full_path) or os.path.islink(full_path):
Expand Down
8 changes: 4 additions & 4 deletions AFQ/viz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@
"Right Cingulum Cingulate": tableau_20[5], "CCMid": tableau_20[5],
"Forceps Minor": tableau_20[8], "CC_ForcepsMinor": tableau_20[8],
"Forceps Major": tableau_20[9], "CC_ForcepsMajor": tableau_20[9],
"Left Inferior Fronto-Occipital": tableau_20[10],
"Left Inferior Fronto-occipital": tableau_20[10],
"IFOF_L": tableau_20[10],
"Right Inferior Fronto-Occipital": tableau_20[11],
"Right Inferior Fronto-occipital": tableau_20[11],
"IFOF_R": tableau_20[11],
"Left Inferior Longitudinal": tableau_20[12], "F_L": tableau_20[12],
"Right Inferior Longitudinal": tableau_20[13], "F_R": tableau_20[13],
Expand Down Expand Up @@ -88,8 +88,8 @@
"MCP": (3, 1), "CCMid": (3, 3),
"Forceps Minor": (4, 2), "Forceps Major": (0, 2),
"CC_ForcepsMinor": (4, 2), "CC_ForcepsMajor": (0, 2),
"Left Inferior Fronto-Occipital": (4, 1),
"Right Inferior Fronto-Occipital": (4, 3),
"Left Inferior Fronto-occipital": (4, 1),
"Right Inferior Fronto-occipital": (4, 3),
"IFOF_L": (4, 1), "IFOF_R": (4, 3),
"Left Inferior Longitudinal": (3, 0),
"Right Inferior Longitudinal": (3, 4),
Expand Down
Loading