Skip to content

Commit

Permalink
Merge pull request #421 from mgxd/rf/nitransforms
Browse files Browse the repository at this point in the history
FIX: Use nitransforms for most xfm handling
  • Loading branch information
mgxd authored Dec 19, 2024
2 parents b7e3fcc + d338dd0 commit e8fa8ad
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 124 deletions.
96 changes: 3 additions & 93 deletions nibabies/interfaces/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
import os
from collections.abc import Callable
from functools import partial
from pathlib import Path
from typing import TypeVar

import h5py
import nibabel as nb
import nitransforms as nt
import numpy as np
Expand All @@ -19,12 +17,13 @@
traits,
)
from nipype.utils.filemanip import fname_presuffix
from nitransforms.io.itk import ITKCompositeH5
from scipy import ndimage as ndi
from scipy.sparse import hstack as sparse_hstack
from sdcflows.transform import grid_bspline_weights
from sdcflows.utils.tools import ensure_positive_cosines

from nibabies.utils.transforms import load_transforms

R = TypeVar('R')


Expand All @@ -34,95 +33,6 @@ async def worker(job: Callable[[], R], semaphore: asyncio.Semaphore) -> R:
return await loop.run_in_executor(None, job)


def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.TransformBase:
"""Load a series of transforms as a nitransforms TransformChain
An empty list will return an identity transform
"""
if len(inverse) == 1:
inverse *= len(xfm_paths)
elif len(inverse) != len(xfm_paths):
raise ValueError('Mismatched number of transforms and inverses')

chain = None
for path, inv in zip(xfm_paths[::-1], inverse[::-1], strict=False):
path = Path(path)
if path.suffix == '.h5':
xfm = load_ants_h5(path)
else:
xfm = nt.linear.load(path)
if inv:
xfm = ~xfm
if chain is None:
chain = xfm
else:
chain += xfm
if chain is None:
chain = nt.base.TransformBase()
return chain


FIXED_PARAMS = np.array([
193.0, 229.0, 193.0, # Size
96.0, 132.0, -78.0, # Origin
1.0, 1.0, 1.0, # Spacing
-1.0, 0.0, 0.0, # Directions
0.0, -1.0, 0.0,
0.0, 0.0, 1.0,
]) # fmt:skip


def load_ants_h5(filename: Path) -> nt.base.TransformBase:
"""Load ANTs H5 files as a nitransforms TransformChain"""
# Borrowed from https://github.com/feilong/process
# process.resample.parse_combined_hdf5()
#
# Changes:
# * Tolerate a missing displacement field
# * Return the original affine without a round-trip
# * Always return a nitransforms TransformChain
#
# This should be upstreamed into nitransforms
h = h5py.File(filename)
xform = ITKCompositeH5.from_h5obj(h)

# nt.Affine
transforms = [nt.Affine(xform[0].to_ras())]

if '2' not in h['TransformGroup']:
return transforms[0]

transform2 = h['TransformGroup']['2']

# Confirm these transformations are applicable
if transform2['TransformType'][:][0] not in (
b'DisplacementFieldTransform_float_3_3',
b'DisplacementFieldTransform_double_3_3',
):
msg = 'Unknown transform type [2]\n'
for i in h['TransformGroup'].keys():
msg += f'[{i}]: {h["TransformGroup"][i]["TransformType"][:][0]}\n'
raise ValueError(msg)

fixed_params = transform2['TransformFixedParameters'][:]
shape = tuple(fixed_params[:3].astype(int))
# ITK stores warps in Fortran-order, where the vector components change fastest
# Nitransforms expects 3 volumes, not a volume of three-vectors, so transpose
warp = np.reshape(
transform2['TransformParameters'],
(3, *shape),
order='F',
).transpose(1, 2, 3, 0)

warp_affine = np.eye(4)
warp_affine[:3, :3] = fixed_params[9:].reshape((3, 3))
warp_affine[:3, 3] = fixed_params[3:6]
lps_to_ras = np.eye(4) * np.array([-1, -1, 1, 1])
warp_affine = lps_to_ras @ warp_affine
transforms.insert(0, nt.DenseFieldTransform(nb.Nifti1Image(warp, warp_affine)))
return nt.TransformChain(transforms)


class ResampleSeriesInputSpec(TraitedSpec):
in_file = File(exists=True, mandatory=True, desc='3D or 4D image file to resample')
ref_file = File(exists=True, mandatory=True, desc='File to resample in_file to')
Expand Down Expand Up @@ -788,7 +698,7 @@ def reconstruct_fieldmap(
)

if not direct:
fmap_img = transforms.apply(fmap_img, reference=target)
fmap_img = nt.apply(transforms, fmap_img, reference=target)

fmap_img.header.set_intent('estimate', name='fieldmap Hz')
fmap_img.header.set_data_dtype('float32')
Expand Down
34 changes: 34 additions & 0 deletions nibabies/utils/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Utilities for loading transforms for resampling"""

from pathlib import Path

import nitransforms as nt


def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.TransformBase:
"""Load a series of transforms as a nitransforms TransformChain
An empty list will return an identity transform
"""
if len(inverse) == 1:
inverse *= len(xfm_paths)
elif len(inverse) != len(xfm_paths):
raise ValueError('Mismatched number of transforms and inverses')

chain = None
for path, inv in zip(xfm_paths[::-1], inverse[::-1], strict=False):
path = Path(path)
if path.suffix == '.h5':
# Load as a TransformChain
xfm = nt.manip.load(path)
else:
xfm = nt.linear.load(path)
if inv:
xfm = ~xfm
if chain is None:
chain = xfm
else:
chain += xfm
if chain is None:
chain = nt.Affine() # Identity
return chain
12 changes: 7 additions & 5 deletions nibabies/workflows/bold/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,11 +704,11 @@ def compare_xforms(lta_list, norm_threshold=15):
second transform relative to the first (default: `15`)
"""
import nitransforms as nt
from nipype.algorithms.rapidart import _calc_norm_affine
from niworkflows.interfaces.surf import load_transform

bbr_affine = load_transform(lta_list[0])
fallback_affine = load_transform(lta_list[1])
bbr_affine = nt.linear.load(lta_list[0]).matrix
fallback_affine = nt.linear.load(lta_list[1]).matrix

norm, _ = _calc_norm_affine([fallback_affine, bbr_affine], use_differences=True)

Expand Down Expand Up @@ -741,14 +741,16 @@ def _conditional_downsampling(in_file, in_mask, zoom_th=4.0):
offset = old_center - newrot.dot((newshape - 1) * 0.5)
newaffine = nb.affines.from_matvec(newrot, offset)

identity = nt.Affine()

newref = nb.Nifti1Image(np.zeros(newshape, dtype=np.uint8), newaffine)
nt.Affine(reference=newref).apply(img).to_filename(out_file)
nt.apply(identity, img, reference=newref).to_filename(out_file)

mask = nb.load(in_mask)
mask.set_data_dtype(float)
mdata = gaussian_filter(mask.get_fdata(dtype=float), scaling)
floatmask = nb.Nifti1Image(mdata, mask.affine, mask.header)
newmask = nt.Affine(reference=newref).apply(floatmask)
newmask = nt.apply(identity, floatmask, reference=newref)
hdr = newmask.header.copy()
hdr.set_data_dtype(np.uint8)
newmaskdata = (newmask.get_fdata(dtype=float) > 0.5).astype(np.uint8)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"nipype >= 1.8.5",
"nireports >= 23.2.0",
"nitime",
"nitransforms >= 23.0.1",
"nitransforms >= 24.1.1",
"niworkflows >= 1.12.1",
"numpy >= 1.21.0",
"packaging",
Expand Down
36 changes: 11 additions & 25 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ annexremote==1.6.6
# datalad-osf
astor==0.8.1
# via formulaic
attrs==24.2.0
attrs==24.3.0
# via
# jsonschema
# niworkflows
Expand All @@ -31,16 +31,14 @@ bidsschematools==1.0.0
# via bids-validator
bokeh==3.5.2
# via tedana
boto3==1.35.80
boto3==1.35.83
# via datalad
botocore==1.35.80
botocore==1.35.83
# via
# boto3
# s3transfer
certifi==2024.8.30
certifi==2024.12.14
# via requests
cffi==1.17.1
# via cryptography
chardet==5.2.0
# via datalad
charset-normalizer==3.4.0
Expand All @@ -58,11 +56,9 @@ contourpy==1.3.1
# via
# bokeh
# matplotlib
cryptography==44.0.0
# via secretstorage
cycler==0.12.1
# via matplotlib
datalad==1.1.4
datalad==1.1.5
# via
# datalad-next
# datalad-osf
Expand All @@ -87,8 +83,6 @@ formulaic==0.5.2
# via pybids
fsspec==2024.10.0
# via universal-pathlib
greenlet==3.1.1
# via sqlalchemy
h5py==3.12.1
# via nitransforms
humanize==4.11.0
Expand Down Expand Up @@ -123,10 +117,6 @@ jaraco-context==6.0.1
# keyrings-alt
jaraco-functools==4.1.0
# via keyring
jeepney==0.8.0
# via
# keyring
# secretstorage
jinja2==3.1.4
# via
# bokeh
Expand Down Expand Up @@ -171,7 +161,7 @@ mapca==0.0.5
# via tedana
markupsafe==3.0.2
# via jinja2
matplotlib==3.9.3
matplotlib==3.10.0
# via
# nireports
# nitime
Expand Down Expand Up @@ -214,7 +204,7 @@ nilearn==0.10.4
# nireports
# niworkflows
# tedana
nipype==1.9.1
nipype==1.9.2
# via
# nibabies (pyproject.toml)
# nireports
Expand All @@ -225,7 +215,7 @@ nireports==24.0.3
# via nibabies (pyproject.toml)
nitime==0.11
# via nibabies (pyproject.toml)
nitransforms==24.1.0
nitransforms==24.1.1
# via
# nibabies (pyproject.toml)
# niworkflows
Expand All @@ -235,7 +225,7 @@ niworkflows==1.12.1
# nibabies (pyproject.toml)
# sdcflows
# smriprep
num2words==0.5.13
num2words==0.5.14
# via pybids
numpy==2.1.1
# via
Expand Down Expand Up @@ -327,8 +317,6 @@ pybtex==0.24.0
# via tedana
pybtex-apa-style==1.3
# via tedana
pycparser==2.22
# via cffi
pydot==3.0.3
# via nipype
pyparsing==3.2.0
Expand All @@ -343,7 +331,7 @@ python-dateutil==2.9.0.post0
# nipype
# pandas
# prov
python-gitlab==5.1.0
python-gitlab==5.2.0
# via datalad
pytz==2024.2
# via pandas
Expand Down Expand Up @@ -384,7 +372,7 @@ rpds-py==0.22.3
# referencing
s3transfer==0.10.4
# via boto3
scikit-image==0.24.0
scikit-image==0.25.0
# via
# niworkflows
# sdcflows
Expand Down Expand Up @@ -415,8 +403,6 @@ seaborn==0.13.2
# via
# nireports
# niworkflows
secretstorage==3.3.3
# via keyring
simplejson==3.19.3
# via nipype
six==1.17.0
Expand Down

0 comments on commit e8fa8ad

Please sign in to comment.