Skip to content

Commit

Permalink
ENH: Use 2 step registration for adult templates
Browse files Browse the repository at this point in the history
  • Loading branch information
mgxd committed Dec 6, 2024
1 parent 8cc8db2 commit 70942bc
Show file tree
Hide file tree
Showing 3 changed files with 439 additions and 14 deletions.
86 changes: 84 additions & 2 deletions nibabies/interfaces/patches.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import nipype.interfaces.freesurfer as fs
from nipype.interfaces.base import File, traits
from pathlib import Path

from nipype.interfaces import (
freesurfer as fs,
)
from nipype.interfaces.ants.base import ANTSCommand, ANTSCommandInputSpec
from nipype.interfaces.base import File, InputMultiObject, TraitedSpec, traits


class _MRICoregInputSpec(fs.registration.MRICoregInputSpec):
Expand All @@ -23,3 +28,80 @@ class MRICoreg(fs.MRICoreg):
"""

input_spec = _MRICoregInputSpec


class ConcatXFMInputSpec(ANTSCommandInputSpec):
transforms = InputMultiObject(
traits.Either(File(exists=True), 'identity'),
argstr='%s',
mandatory=True,
desc='transform files: will be applied in reverse order. For '
'example, the last specified transform will be applied first.',
)
out_xfm = traits.File(
'concat_xfm.h5',
usedefault=True,
argstr='--output [ %s, 1 ]',
desc='output file name',
)
reference_image = File(
argstr='--reference-image %s',
mandatory=True,
desc='reference image space that you wish to warp INTO',
exists=True,
)
invert_transform_flags = InputMultiObject(traits.Bool())


class ConcatXFMOutputSpec(TraitedSpec):
out_xfm = File(desc='Combined transform')


class ConcatXFM(ANTSCommand):
"""
Streamed use of antsApplyTransforms to combine nonlinear xfms into a single file
Examples
--------
>>> from nibabies.interfaces.patches import ConcatXFM
>>> cxfm = ConcatXFM()
>>> cxfm.inputs.transforms = ['xfm1.h5', 'xfm0.h5']
>>> cxfm.inputs.reference_image = 'sub-01_T1w.nii.gz'
>>> cxfm.cmdline
'antsApplyTransforms --output [ concat_xfm.h5, 1 ] --transform .../xfm1.h5 \
--transform .../xfm0.h5 --reference_image .../sub-01_T1w.nii.gz'
"""

_cmd = 'antsApplyTransforms'
input_spec = ConcatXFMInputSpec
output_spec = ConcatXFMOutputSpec

def _get_transform_filenames(self):
retval = []
invert_flags = self.inputs.invert_transform_flags

Check warning on line 83 in nibabies/interfaces/patches.py

View check run for this annotation

Codecov / codecov/patch

nibabies/interfaces/patches.py#L82-L83

Added lines #L82 - L83 were not covered by tests
if not invert_flags:
invert_flags = [False] * len(self.inputs.transforms)

Check warning on line 85 in nibabies/interfaces/patches.py

View check run for this annotation

Codecov / codecov/patch

nibabies/interfaces/patches.py#L85

Added line #L85 was not covered by tests
elif len(self.inputs.transforms) != len(invert_flags):
raise ValueError(

Check warning on line 87 in nibabies/interfaces/patches.py

View check run for this annotation

Codecov / codecov/patch

nibabies/interfaces/patches.py#L87

Added line #L87 was not covered by tests
'ERROR: The invert_transform_flags list must have the same number '
'of entries as the transforms list.'
)

for transform, invert in zip(self.inputs.transforms, invert_flags, strict=False):
if invert:
retval.append(f'--transform [ {transform}, 1 ]')

Check warning on line 94 in nibabies/interfaces/patches.py

View check run for this annotation

Codecov / codecov/patch

nibabies/interfaces/patches.py#L94

Added line #L94 was not covered by tests
else:
retval.append(f'--transform {transform}')
return ' '.join(retval)

Check warning on line 97 in nibabies/interfaces/patches.py

View check run for this annotation

Codecov / codecov/patch

nibabies/interfaces/patches.py#L96-L97

Added lines #L96 - L97 were not covered by tests

def _format_arg(self, opt, spec, val):
if opt == 'transforms':
return self._get_transform_filenames()
return super()._format_arg(opt, spec, val)

Check warning on line 102 in nibabies/interfaces/patches.py

View check run for this annotation

Codecov / codecov/patch

nibabies/interfaces/patches.py#L101-L102

Added lines #L101 - L102 were not covered by tests

def _list_outputs(self):
outputs = self._outputs().get()
outputs['out_xfm'] = Path(self.inputs.out_xfm).absolute()
return outputs

Check warning on line 107 in nibabies/interfaces/patches.py

View check run for this annotation

Codecov / codecov/patch

nibabies/interfaces/patches.py#L105-L107

Added lines #L105 - L107 were not covered by tests
166 changes: 155 additions & 11 deletions nibabies/workflows/anatomical/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms
from niworkflows.interfaces.header import ValidateImage
from niworkflows.interfaces.nibabel import ApplyMask, Binarize
from niworkflows.interfaces.utility import KeySelect
from niworkflows.utils.connections import pop_file
from smriprep.workflows.anatomical import (
_is_skull_stripped,
Expand Down Expand Up @@ -37,7 +38,10 @@
from nibabies.workflows.anatomical.brain_extraction import init_infant_brain_extraction_wf
from nibabies.workflows.anatomical.outputs import init_anat_reports_wf
from nibabies.workflows.anatomical.preproc import init_anat_preproc_wf
from nibabies.workflows.anatomical.registration import init_coregistration_wf
from nibabies.workflows.anatomical.registration import (
init_concat_registrations_wf,
init_coregistration_wf,
)
from nibabies.workflows.anatomical.segmentation import init_segmentation_wf
from nibabies.workflows.anatomical.surfaces import init_mcribs_dhcp_wf

Expand Down Expand Up @@ -220,9 +224,9 @@ def init_infant_anat_fit_wf(
name='seg_buffer',
)
# Stage 5 - collated template names, forward and reverse transforms
template_buffer = pe.Node(niu.Merge(2), name='template_buffer')
anat2std_buffer = pe.Node(niu.Merge(2), name='anat2std_buffer')
std2anat_buffer = pe.Node(niu.Merge(2), name='std2anat_buffer')
template_buffer = pe.Node(niu.Merge(3), name='template_buffer')
anat2std_buffer = pe.Node(niu.Merge(3), name='anat2std_buffer')
std2anat_buffer = pe.Node(niu.Merge(3), name='std2anat_buffer')

Check warning on line 229 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L227-L229

Added lines #L227 - L229 were not covered by tests

# Stage 6 results: Refined stage 2 results; may be direct copy if no refinement
refined_buffer = pe.Node(
Expand Down Expand Up @@ -885,21 +889,48 @@ def init_infant_anat_fit_wf(
seg_buffer.inputs.anat_tpms = anat_tpms

# Stage 5: Normalization

# If selected adult templates are requested (MNI152 6th Gen or 2009)
# opt to concatenate transforms first from native -> infant template (MNIInfant),
# and then use a previously computed MNIInfant<cohort> -> MNI transform
# this minimizes the chance of a bad registration.

templates = []
concat_xfms = []

Check warning on line 899 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L899

Added line #L899 was not covered by tests
found_xfms = {}
intermediate = None # The intermediate space when concatenating xfms - includes cohort
intermediate_targets = {

Check warning on line 902 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L901-L902

Added lines #L901 - L902 were not covered by tests
'MNI152NLin6Asym',
} # TODO: 'MNI152NLin2009cAsym'

for template in spaces.get_spaces(nonstandard=False, dim=(3,)):
# resolution / spec will not differentiate here
if template.startswith('MNIInfant'):
intermediate = template

Check warning on line 909 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L909

Added line #L909 was not covered by tests
xfms = precomputed.get('transforms', {}).get(template, {})
if set(xfms) != {'forward', 'reverse'}:
templates.append(template)
if template in intermediate_targets:
concat_xfms.append(template)

Check warning on line 913 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L913

Added line #L913 was not covered by tests
else:
templates.append(template)

Check warning on line 915 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L915

Added line #L915 was not covered by tests
else:
found_xfms[template] = xfms

# Create another set of buffers to handle the case where we aggregate found and generated
# xfms to be concatenated
concat_template_buffer = pe.Node(niu.Merge(2), name='concat_template_buffer')
concat_template_buffer.inputs.in1 = list(found_xfms)
concat_anat2std_buffer = pe.Node(niu.Merge(2), name='concat_anat2std_buffer')
concat_anat2std_buffer.inputs.in1 = [xfm['forward'] for xfm in found_xfms.values()]
concat_std2anat_buffer = pe.Node(niu.Merge(2), name='concat_std2anat_buffer')
concat_std2anat_buffer.inputs.in1 = [xfm['reverse'] for xfm in found_xfms.values()]

Check warning on line 926 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L921-L926

Added lines #L921 - L926 were not covered by tests

template_buffer.inputs.in1 = list(found_xfms)
anat2std_buffer.inputs.in1 = [xfm['forward'] for xfm in found_xfms.values()]
std2anat_buffer.inputs.in1 = [xfm['reverse'] for xfm in found_xfms.values()]

if templates:
LOGGER.info(f'ANAT Stage 5: Preparing normalization workflow for {templates}')
LOGGER.info(f'ANAT Stage 5a: Preparing normalization workflow for {templates}')

Check warning on line 933 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L933

Added line #L933 was not covered by tests
register_template_wf = init_register_template_wf(
sloppy=sloppy,
omp_nthreads=omp_nthreads,
Expand All @@ -923,11 +954,53 @@ def init_infant_anat_fit_wf(
('outputnode.std2anat_xfm', 'inputnode.std2anat_xfm'),
]),
(register_template_wf, template_buffer, [('outputnode.template', 'in2')]),
(register_template_wf, concat_template_buffer, [('outputnode.template', 'in2')]),
(register_template_wf, std2anat_buffer, [('outputnode.std2anat_xfm', 'in2')]),
(register_template_wf, concat_std2anat_buffer, [('outputnode.std2anat_xfm', 'in2')]),
(register_template_wf, anat2std_buffer, [('outputnode.anat2std_xfm', 'in2')]),
(register_template_wf, concat_anat2std_buffer, [('outputnode.anat2std_xfm', 'in2')]),
]) # fmt:skip

if concat_xfms:
LOGGER.info(f'ANAT Stage 5b: Concatenating normalization for {concat_xfms}')

Check warning on line 965 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L965

Added line #L965 was not covered by tests
# 1. Select intermediate's transforms
select_infant_mni = pe.Node(

Check warning on line 967 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L967

Added line #L967 was not covered by tests
KeySelect(fields=['template', 'anat2std_xfm', 'std2anat_xfm'], key=intermediate),
name='select_infant_mni',
run_without_submitting=True,
)
concat_reg_wf = init_concat_registrations_wf(templates=concat_xfms)
ds_concat_reg_wf = init_ds_template_registration_wf(

Check warning on line 973 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L972-L973

Added lines #L972 - L973 were not covered by tests
output_dir=str(output_dir),
image_type=reference_anat,
name='ds_concat_registration_wf',
)

workflow.connect([

Check warning on line 979 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L979

Added line #L979 was not covered by tests
(concat_template_buffer, select_infant_mni, [('out', 'template')]),
(concat_anat2std_buffer, select_infant_mni, [('out', 'anat2std_xfm')]),
(concat_std2anat_buffer, select_infant_mni, [('out', 'std2anat_xfm')]),
(select_infant_mni, concat_reg_wf, [
('template', 'inputnode.intermediate'),
('anat2std_xfm', 'inputnode.anat2std_xfm'),
('std2anat_xfm', 'inputnode.std2anat_xfm'),
]),
(anat_buffer, concat_reg_wf, [('anat_preproc', 'inputnode.anat_preproc')]),
(sourcefile_buffer, ds_concat_reg_wf, [
('anat_source_files', 'inputnode.source_files')
]),
(concat_reg_wf, ds_concat_reg_wf, [
('outputnode.template', 'inputnode.template'),
('outputnode.anat2std_xfm', 'inputnode.anat2std_xfm'),
('outputnode.std2anat_xfm', 'inputnode.std2anat_xfm'),
]),
(concat_reg_wf, template_buffer, [('outputnode.template', 'in3')]),
(concat_reg_wf, anat2std_buffer, [('outputnode.anat2std_xfm', 'in3')]),
(concat_reg_wf, std2anat_buffer, [('outputnode.std2anat_xfm', 'in3')]),
]) # fmt:skip

if found_xfms:
LOGGER.info(f'ANAT Stage 5: Found pre-computed registrations for {found_xfms}')
LOGGER.info(f'ANAT Stage 5c: Found pre-computed registrations for {found_xfms}')

Check warning on line 1003 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L1003

Added line #L1003 was not covered by tests

# Only refine mask if necessary
if anat_mask or recon_method is None or not refine_mask:
Expand Down Expand Up @@ -1394,9 +1467,9 @@ def init_infant_single_anat_fit_wf(
name='seg_buffer',
)
# Stage 4 - collated template names, forward and reverse transforms
template_buffer = pe.Node(niu.Merge(2), name='template_buffer')
anat2std_buffer = pe.Node(niu.Merge(2), name='anat2std_buffer')
std2anat_buffer = pe.Node(niu.Merge(2), name='std2anat_buffer')
template_buffer = pe.Node(niu.Merge(3), name='template_buffer')
anat2std_buffer = pe.Node(niu.Merge(3), name='anat2std_buffer')
std2anat_buffer = pe.Node(niu.Merge(3), name='std2anat_buffer')

Check warning on line 1472 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L1470-L1472

Added lines #L1470 - L1472 were not covered by tests

# Stage 5 results: Refined stage 2 results; may be direct copy if no refinement
refined_buffer = pe.Node(
Expand Down Expand Up @@ -1710,15 +1783,41 @@ def init_infant_single_anat_fit_wf(
seg_buffer.inputs.anat_tpms = anat_tpms

# Stage 4: Normalization
# If selected adult templates are requested (MNI152 6th Gen or 2009)
# opt to concatenate transforms first from native -> infant template (MNIInfant),
# and then use a previously computed MNIInfant<cohort> -> MNI transform
# this minimizes the chance of a bad registration.

templates = []
concat_xfms = []

Check warning on line 1792 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L1792

Added line #L1792 was not covered by tests
found_xfms = {}
intermediate = None # The intermediate space when concatenating xfms - includes cohort
intermediate_targets = {

Check warning on line 1795 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L1794-L1795

Added lines #L1794 - L1795 were not covered by tests
'MNI152NLin6Asym',
} # TODO: 'MNI152NLin2009cAsym'

for template in spaces.get_spaces(nonstandard=False, dim=(3,)):
# resolution / spec will not differentiate here
if template.startswith('MNIInfant'):
intermediate = template

Check warning on line 1802 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L1802

Added line #L1802 was not covered by tests
xfms = precomputed.get('transforms', {}).get(template, {})
if set(xfms) != {'forward', 'reverse'}:
templates.append(template)
if template in intermediate_targets:
concat_xfms.append(template)

Check warning on line 1806 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L1806

Added line #L1806 was not covered by tests
else:
templates.append(template)

Check warning on line 1808 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L1808

Added line #L1808 was not covered by tests
else:
found_xfms[template] = xfms

# Create another set of buffers to handle the case where we aggregate found and generated
# xfms to be concatenated
concat_template_buffer = pe.Node(niu.Merge(2), name='concat_template_buffer')
concat_template_buffer.inputs.in1 = list(found_xfms)
concat_anat2std_buffer = pe.Node(niu.Merge(2), name='concat_anat2std_buffer')
concat_anat2std_buffer.inputs.in1 = [xfm['forward'] for xfm in found_xfms.values()]
concat_std2anat_buffer = pe.Node(niu.Merge(2), name='concat_std2anat_buffer')
concat_std2anat_buffer.inputs.in1 = [xfm['reverse'] for xfm in found_xfms.values()]

Check warning on line 1819 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L1814-L1819

Added lines #L1814 - L1819 were not covered by tests

template_buffer.inputs.in1 = list(found_xfms)
anat2std_buffer.inputs.in1 = [xfm['forward'] for xfm in found_xfms.values()]
std2anat_buffer.inputs.in1 = [xfm['reverse'] for xfm in found_xfms.values()]
Expand Down Expand Up @@ -1748,9 +1847,54 @@ def init_infant_single_anat_fit_wf(
('outputnode.std2anat_xfm', 'inputnode.std2anat_xfm'),
]),
(register_template_wf, template_buffer, [('outputnode.template', 'in2')]),
(register_template_wf, concat_template_buffer, [('outputnode.template', 'in2')]),
(register_template_wf, std2anat_buffer, [('outputnode.std2anat_xfm', 'in2')]),
(register_template_wf, concat_std2anat_buffer, [('outputnode.std2anat_xfm', 'in2')]),
(register_template_wf, anat2std_buffer, [('outputnode.anat2std_xfm', 'in2')]),
(register_template_wf, concat_anat2std_buffer, [('outputnode.anat2std_xfm', 'in2')]),
]) # fmt:skip

if concat_xfms:
LOGGER.info(f'ANAT Stage 5b: Concatenating normalization for {concat_xfms}')

Check warning on line 1858 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L1858

Added line #L1858 was not covered by tests
# 1. Select intermediate's transforms
select_infant_mni = pe.Node(

Check warning on line 1860 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L1860

Added line #L1860 was not covered by tests
KeySelect(fields=['template', 'anat2std_xfm', 'std2anat_xfm']),
name='select_infant_mni',
run_without_submitting=True,
)
select_infant_mni.inputs.key = intermediate

Check warning on line 1865 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L1865

Added line #L1865 was not covered by tests

concat_reg_wf = init_concat_registrations_wf(templates=concat_xfms)
ds_concat_reg_wf = init_ds_template_registration_wf(

Check warning on line 1868 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L1867-L1868

Added lines #L1867 - L1868 were not covered by tests
output_dir=str(output_dir),
image_type=reference_anat,
name='ds_concat_registration_wf',
)

workflow.connect([

Check warning on line 1874 in nibabies/workflows/anatomical/fit.py

View check run for this annotation

Codecov / codecov/patch

nibabies/workflows/anatomical/fit.py#L1874

Added line #L1874 was not covered by tests
(concat_template_buffer, select_infant_mni, [('out', 'keys')]),
(concat_template_buffer, select_infant_mni, [('out', 'template')]),
(concat_anat2std_buffer, select_infant_mni, [('out', 'anat2std_xfm')]),
(concat_std2anat_buffer, select_infant_mni, [('out', 'std2anat_xfm')]),
(select_infant_mni, concat_reg_wf, [
('template', 'inputnode.intermediate'),
('anat2std_xfm', 'inputnode.anat2std_xfm'),
('std2anat_xfm', 'inputnode.std2anat_xfm'),
]),
(anat_buffer, concat_reg_wf, [('anat_preproc', 'inputnode.anat_preproc')]),
(sourcefile_buffer, ds_concat_reg_wf, [
('anat_source_files', 'inputnode.source_files')
]),
(concat_reg_wf, ds_concat_reg_wf, [
('outputnode.template', 'inputnode.template'),
('outputnode.anat2std_xfm', 'inputnode.anat2std_xfm'),
('outputnode.std2anat_xfm', 'inputnode.std2anat_xfm'),
]),
(concat_reg_wf, template_buffer, [('outputnode.template', 'in3')]),
(concat_reg_wf, anat2std_buffer, [('outputnode.anat2std_xfm', 'in3')]),
(concat_reg_wf, std2anat_buffer, [('outputnode.std2anat_xfm', 'in3')]),
]) # fmt:skip

if found_xfms:
LOGGER.info(f'ANAT Stage 4: Found pre-computed registrations for {found_xfms}')

Expand Down
Loading

0 comments on commit 70942bc

Please sign in to comment.