Skip to content

Commit

Permalink
Merge pull request #419 from mgxd/enh/spatial-norm
Browse files Browse the repository at this point in the history
FEAT: Option to normalize CSF prior to template registration
  • Loading branch information
mgxd authored Dec 13, 2024
2 parents 9476035 + 8981735 commit b9efa6f
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 26 deletions.
5 changes: 5 additions & 0 deletions nibabies/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,11 @@ def _str_none(val):
default=16,
help='Frame to start head motion estimation on BOLD.',
)
g_baby.add_argument(
'--norm-csf',
action='store_true',
help='Replace low intensity voxels in CSF mask with average',
)
return parser


Expand Down
2 changes: 2 additions & 0 deletions nibabies/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,8 @@ class workflow(_Config):
"""Run FreeSurfer ``recon-all`` with the ``-logitudinal`` flag."""
medial_surface_nan = None
"""Fill medial surface with :abbr:`NaNs (not-a-number)` when sampling."""
norm_csf = False
"""Replace low intensity voxels in CSF mask with average."""
project_goodvoxels = False
"""Exclude voxels with locally high coefficient of variation from sampling."""
regressors_all_comps = None
Expand Down
70 changes: 45 additions & 25 deletions nibabies/workflows/anatomical/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from nibabies import config
from nibabies.workflows.anatomical.brain_extraction import init_infant_brain_extraction_wf
from nibabies.workflows.anatomical.outputs import init_anat_reports_wf
from nibabies.workflows.anatomical.preproc import init_anat_preproc_wf
from nibabies.workflows.anatomical.preproc import init_anat_preproc_wf, init_csf_norm_wf
from nibabies.workflows.anatomical.registration import init_coregistration_wf
from nibabies.workflows.anatomical.segmentation import init_segmentation_wf
from nibabies.workflows.anatomical.surfaces import init_mcribs_dhcp_wf
Expand Down Expand Up @@ -184,6 +184,14 @@ def init_infant_anat_fit_wf(
name='anat_buffer',
)

# Additional buffer if CSF normalization is used
anat_preproc_buffer = pe.Node(
niu.IdentityInterface(fields=['anat_preproc']),
name='anat_preproc_buffer',
)
if not config.workflow.norm_csf:
workflow.connect(anat_buffer, 'anat_preproc', anat_preproc_buffer, 'anat_preproc')

if reference_anat == 'T1w':
LOGGER.info('ANAT: Using T1w as the reference anatomical')
workflow.connect([
Expand Down Expand Up @@ -248,7 +256,7 @@ def init_infant_anat_fit_wf(
msm_buffer = pe.Node(niu.IdentityInterface(fields=['sphere_reg_msm']), name='msm_buffer')

workflow.connect([
(anat_buffer, outputnode, [
(anat_preproc_buffer, outputnode, [
('anat_preproc', 'anat_preproc'),
]),
(refined_buffer, outputnode, [
Expand Down Expand Up @@ -637,24 +645,6 @@ def init_infant_anat_fit_wf(
(binarize_t2w, t2w_buffer, [('out_file', 't2w_mask')]),
]) # fmt:skip
else:
# Check whether we can convert a previously computed T2w mask
# or need to run the atlas based brain extraction

# if t1w_mask:
# LOGGER.info('ANAT T1w mask will be transformed into T2w space')
# transform_t1w_mask = pe.Node(
# ApplyTransforms(interpolation='MultiLabel'),
# name='transform_t1w_mask',
# )

# workflow.connect([
# (t1w_buffer, transform_t1w_mask, [('t1w_mask', 'input_image')]),
# (coreg_buffer, transform_t1w_mask, [('t1w2t2w_xfm', 'transforms')]),
# (transform_t1w_mask, apply_t2w_mask, [('output_image', 'in_mask')]),
# (t2w_buffer, apply_t1w_mask, [('t2w_preproc', 'in_file')]),
# # TODO: Unsure about this connection^
# ]) # fmt:skip
# else:
LOGGER.info('ANAT Atlas-based brain mask will be calculated on the T2w')
brain_extraction_wf = init_infant_brain_extraction_wf(
omp_nthreads=omp_nthreads,
Expand Down Expand Up @@ -898,6 +888,15 @@ def init_infant_anat_fit_wf(
anat2std_buffer.inputs.in1 = [xfm['forward'] for xfm in found_xfms.values()]
std2anat_buffer.inputs.in1 = [xfm['reverse'] for xfm in found_xfms.values()]

if config.workflow.norm_csf:
csf_norm_wf = init_csf_norm_wf()

workflow.connect([
(anat_buffer, csf_norm_wf, [('anat_preproc', 'inputnode.anat_preproc')]),
(seg_buffer, csf_norm_wf, [('anat_tpms', 'inputnode.anat_tpms')]),
(csf_norm_wf, anat_preproc_buffer, [('outputnode.anat_preproc', 'anat_preproc')]),
]) # fmt:skip

if templates:
LOGGER.info(f'ANAT Stage 5: Preparing normalization workflow for {templates}')
register_template_wf = init_register_template_wf(
Expand All @@ -913,7 +912,9 @@ def init_infant_anat_fit_wf(

workflow.connect([
(inputnode, register_template_wf, [('roi', 'inputnode.lesion_mask')]),
(anat_buffer, register_template_wf, [('anat_preproc', 'inputnode.moving_image')]),
(anat_preproc_buffer, register_template_wf, [
('anat_preproc', 'inputnode.moving_image'),
]),
(refined_buffer, register_template_wf, [('anat_mask', 'inputnode.moving_mask')]),
(sourcefile_buffer, ds_template_registration_wf, [
('anat_source_files', 'inputnode.source_files')
Expand Down Expand Up @@ -1106,7 +1107,7 @@ def init_infant_anat_fit_wf(
(seg_buffer, refinement_wf, [
('ants_segs', 'inputnode.ants_segs'), # TODO: Verify this is the same as dseg
]),
(anat_buffer, applyrefined, [('anat_preproc', 'in_file')]),
(anat_preproc_buffer, applyrefined, [('anat_preproc', 'in_file')]),
(refinement_wf, applyrefined, [('outputnode.out_brainmask', 'in_mask')]),
(refinement_wf, refined_buffer, [('outputnode.out_brainmask', 'anat_mask')]),
(applyrefined, refined_buffer, [('out_file', 'anat_brain')]),
Expand Down Expand Up @@ -1384,6 +1385,14 @@ def init_infant_single_anat_fit_wf(
name='anat_buffer',
)

# Additional buffer if CSF normalization is used
anat_preproc_buffer = pe.Node(
niu.IdentityInterface(fields=['anat_preproc']),
name='anat_preproc_buffer',
)
if not config.workflow.norm_csf:
workflow.connect(anat_buffer, 'anat_preproc', anat_preproc_buffer, 'anat_preproc')

aseg_buffer = pe.Node(
niu.IdentityInterface(fields=['anat_aseg']),
name='aseg_buffer',
Expand Down Expand Up @@ -1423,7 +1432,7 @@ def init_infant_single_anat_fit_wf(
msm_buffer = pe.Node(niu.IdentityInterface(fields=['sphere_reg_msm']), name='msm_buffer')

workflow.connect([
(anat_buffer, outputnode, [
(anat_preproc_buffer, outputnode, [
('anat_preproc', 'anat_preproc'),
]),
(refined_buffer, outputnode, [
Expand Down Expand Up @@ -1724,6 +1733,15 @@ def init_infant_single_anat_fit_wf(
anat2std_buffer.inputs.in1 = [xfm['forward'] for xfm in found_xfms.values()]
std2anat_buffer.inputs.in1 = [xfm['reverse'] for xfm in found_xfms.values()]

if config.workflow.norm_csf:
csf_norm_wf = init_csf_norm_wf()

workflow.connect([
(anat_buffer, csf_norm_wf, [('anat_preproc', 'inputnode.anat_preproc')]),
(seg_buffer, csf_norm_wf, [('anat_tpms', 'inputnode.anat_tpms')]),
(csf_norm_wf, anat_preproc_buffer, [('outputnode.anat_preproc', 'anat_preproc')]),
]) # fmt:skip

if templates:
LOGGER.info(f'ANAT Stage 4: Preparing normalization workflow for {templates}')
register_template_wf = init_register_template_wf(
Expand All @@ -1739,7 +1757,9 @@ def init_infant_single_anat_fit_wf(

workflow.connect([
(inputnode, register_template_wf, [('roi', 'inputnode.lesion_mask')]),
(anat_buffer, register_template_wf, [('anat_preproc', 'inputnode.moving_image')]),
(anat_preproc_buffer, register_template_wf, [
('anat_preproc', 'inputnode.moving_image'),
]),
(refined_buffer, register_template_wf, [('anat_mask', 'inputnode.moving_mask')]),
(sourcefile_buffer, ds_template_registration_wf, [
('anat_source_files', 'inputnode.source_files')
Expand Down Expand Up @@ -1921,7 +1941,7 @@ def init_infant_single_anat_fit_wf(
(seg_buffer, refinement_wf, [
('ants_segs', 'inputnode.ants_segs'),
]),
(anat_buffer, applyrefined, [('anat_preproc', 'in_file')]),
(anat_preproc_buffer, applyrefined, [('anat_preproc', 'in_file')]),
(refinement_wf, applyrefined, [('outputnode.out_brainmask', 'in_mask')]),
(refinement_wf, refined_buffer, [('outputnode.out_brainmask', 'anat_mask')]),
(applyrefined, refined_buffer, [('out_file', 'anat_brain')]),
Expand Down
53 changes: 53 additions & 0 deletions nibabies/workflows/anatomical/preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,56 @@ def init_anat_preproc_wf(
(final_clip, outputnode, [('out_file', 'anat_preproc')]),
]) # fmt:skip
return wf


def init_csf_norm_wf(name: str = 'csf_norm_wf') -> LiterateWorkflow:
"""Replace low intensity voxels within the CSF mask with the median value."""

workflow = LiterateWorkflow(name=name)
workflow.__desc__ = (
'The CSF mask was used to normalize the anatomical template by the median of voxels '
'within the mask.'
)
inputnode = pe.Node(
niu.IdentityInterface(fields=['anat_preproc', 'anat_tpms']),
name='inputnode',
)
outputnode = pe.Node(niu.IdentityInterface(fields=['anat_preproc']), name='outputnode')

# select CSF from BIDS-ordered list (GM, WM, CSF)
select_csf = pe.Node(niu.Select(index=2), name='select_csf')
norm_csf = pe.Node(niu.Function(function=_normalize_roi), name='norm_csf')

workflow.connect([
(inputnode, select_csf, [('anat_tpms', 'inlist')]),
(select_csf, norm_csf, [('out', 'mask_file')]),
(inputnode, norm_csf, [('anat_preproc', 'in_file')]),
(norm_csf, outputnode, [('out', 'anat_preproc')]),
]) # fmt:skip

return workflow


def _normalize_roi(in_file, mask_file, threshold=0.2, out_file=None):
"""Normalize low intensity voxels that fall within a given mask."""
import nibabel as nb
import numpy as np

img = nb.load(in_file)
img_data = np.asanyarray(img.dataobj)
mask_img = nb.load(mask_file)
# binary mask
bin_mask = np.asanyarray(mask_img.dataobj) > threshold
mask_data = bin_mask * img_data
masked_data = mask_data[mask_data > 0]

median = np.median(masked_data).astype(masked_data.dtype)
normed_data = np.maximum(img_data, bin_mask * median)

oimg = img.__class__(normed_data, img.affine, img.header)
if not out_file:
from nipype.utils.filemanip import fname_presuffix

out_file = fname_presuffix(in_file, suffix='normed')
oimg.to_filename(out_file)
return out_file
Empty file.
55 changes: 55 additions & 0 deletions nibabies/workflows/anatomical/tests/test_preproc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import typing as ty
from pathlib import Path

import nibabel as nb
import numpy as np
import pytest

from nibabies.workflows.anatomical.preproc import _normalize_roi, init_csf_norm_wf

EXPECTED_CSF_NORM = np.array([[[10, 73], [73, 29]], [[77, 80], [6, 16]]], dtype='uint8')


@pytest.fixture
def csf_norm_data(tmp_path) -> ty.Generator[tuple[Path, list[Path]], None, None]:
np.random.seed(10)

in_file = tmp_path / 'input.nii.gz'
data = np.random.randint(1, 101, size=(2, 2, 2), dtype='uint8')
img = nb.Nifti1Image(data, np.eye(4))
img.to_filename(in_file)

masks = []
for tpm in ('gm', 'wm', 'csf'):
name = tmp_path / f'{tpm}.nii.gz'
binmask = data > np.random.randint(10, 90)
masked = (binmask * 1).astype('uint8')
mask = nb.Nifti1Image(masked, img.affine)
mask.to_filename(name)
masks.append(name)

yield in_file, masks

in_file.unlink()
for m in masks:
m.unlink()


def test_csf_norm_wf(tmp_path, csf_norm_data):
anat, tpms = csf_norm_data
wf = init_csf_norm_wf()
wf.base_dir = tmp_path

wf.inputs.inputnode.anat_preproc = anat
wf.inputs.inputnode.anat_tpms = tpms

# verify workflow runs
wf.run()

# verify function works as expected
outfile = _normalize_roi(anat, tpms[2])
assert np.array_equal(
np.asanyarray(nb.load(outfile).dataobj),
EXPECTED_CSF_NORM,
)
Path(outfile).unlink()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies = [
"requests",
"sdcflows >= 2.10.0",
# "smriprep >= 0.16.1",
"smriprep @ git+https://github.com/nipreps/smriprep.git@master",
"smriprep @ git+https://github.com/nipreps/smriprep.git@dev-nibabies",
"tedana >= 23.0.2",
"templateflow >= 24.2.0",
"toml",
Expand Down

0 comments on commit b9efa6f

Please sign in to comment.