Skip to content

Commit

Permalink
[ENH] Add a series of general-purpose and emc-related image-handling …
Browse files Browse the repository at this point in the history
…helper functions in a new module utils/images.py, and relocate dangling image helper functions that were previously in interfaces/images.py into this module
  • Loading branch information
dPys committed Feb 11, 2020
1 parent 7e194c3 commit 5f5ebb1
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 36 deletions.
57 changes: 21 additions & 36 deletions dmriprep/interfaces/images.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Image tools interfaces."""
import numpy as np
import nibabel as nb
from dmriprep.utils.images import rescale_b0, median, match_transforms
from nipype.utils.filemanip import fname_presuffix
from nipype import logging
from nipype.interfaces.base import (
traits, TraitedSpec, BaseInterfaceInputSpec, SimpleInterface, File
traits, TraitedSpec, BaseInterfaceInputSpec, SimpleInterface, File,
InputMultiObject, OutputMultiObject
)

LOGGER = logging.getLogger('nipype.interface')
Expand Down Expand Up @@ -103,43 +105,26 @@ def _run_interface(self, runtime):
return runtime


def rescale_b0(in_file, mask_file, newpath=None):
"""Rescale the input volumes using the median signal intensity."""
out_file = fname_presuffix(
in_file, suffix='_rescaled_b0', newpath=newpath)
class MatchTransformsInputSpec(BaseInterfaceInputSpec):
b0_indices = traits.List(mandatory=True)
dwi_files = InputMultiObject(File(exists=True), mandatory=True)
transforms = InputMultiObject(File(exists=True), mandatory=True)

img = nb.load(in_file)
if img.dataobj.ndim == 3:
return in_file

data = img.get_fdata(dtype='float32')
mask_img = nb.load(mask_file)
mask_data = mask_img.get_fdata(dtype='float32')
class MatchTransformsOutputSpec(TraitedSpec):
transforms = OutputMultiObject(File(exists=True), mandatory=True)

median_signal = np.median(data[mask_data > 0, ...], axis=0)
rescaled_data = 1000 * data / median_signal
hdr = img.header.copy()
nb.Nifti1Image(rescaled_data, img.affine, hdr).to_filename(out_file)
return out_file

class MatchTransforms(SimpleInterface):
input_spec = MatchTransformsInputSpec
output_spec = MatchTransformsOutputSpec

def median(in_file, newpath=None):
"""Average a 4D dataset across the last dimension using median."""
out_file = fname_presuffix(
in_file, suffix='_b0ref', newpath=newpath)

img = nb.load(in_file)
if img.dataobj.ndim == 3:
return in_file
if img.shape[-1] == 1:
nb.squeeze_image(img).to_filename(out_file)
return out_file

median_data = np.median(img.get_fdata(dtype='float32'),
axis=-1)

hdr = img.header.copy()
hdr.set_xyzt_units('mm')
hdr.set_data_dtype(np.float32)
nb.Nifti1Image(median_data, img.affine, hdr).to_filename(out_file)
return out_file
def _run_interface(self, runtime):
"""
Interface for mapping the `match_transforms` function across lists of inputs.
"""
self._results["transforms"] = match_transforms(
self.inputs.dwi_files, self.inputs.transforms, self.inputs.b0_indices
)
return runtime

184 changes: 184 additions & 0 deletions dmriprep/utils/images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import numpy as np
import nibabel as nb
from nipype.utils.filemanip import fname_presuffix


def extract_b0(in_file, b0_ixs, newpath=None):
"""Extract the *b0* volumes from a DWI dataset."""
out_file = fname_presuffix(in_file, suffix="_b0", newpath=newpath)

img = nb.load(in_file)
data = img.get_fdata(dtype="float32")

b0 = data[..., b0_ixs]

hdr = img.header.copy()
hdr.set_data_shape(b0.shape)
hdr.set_xyzt_units("mm")
hdr.set_data_dtype(np.float32)
nb.Nifti1Image(b0, img.affine, hdr).to_filename(out_file)
return out_file


def rescale_b0(in_file, mask_file, newpath=None):
"""Rescale the input volumes using the median signal intensity."""
out_file = fname_presuffix(in_file, suffix="_rescaled_b0", newpath=newpath)

img = nb.load(in_file)
if img.dataobj.ndim == 3:
return in_file

data = img.get_fdata(dtype="float32")
mask_img = nb.load(mask_file)
mask_data = mask_img.get_fdata(dtype="float32")

median_signal = np.median(data[mask_data > 0, ...], axis=0)
rescaled_data = 1000 * data / median_signal
hdr = img.header.copy()
nb.Nifti1Image(rescaled_data, img.affine, hdr).to_filename(out_file)
return out_file


def median(in_file, newpath=None):
"""Average a 4D dataset across the last dimension using median."""
out_file = fname_presuffix(in_file, suffix="_b0ref", newpath=newpath)

img = nb.load(in_file)
if img.dataobj.ndim == 3:
return in_file
if img.shape[-1] == 1:
nb.squeeze_image(img).to_filename(out_file)
return out_file

median_data = np.median(img.get_fdata(dtype="float32"), axis=-1)

hdr = img.header.copy()
hdr.set_xyzt_units("mm")
hdr.set_data_dtype(np.float32)
nb.Nifti1Image(median_data, img.affine, hdr).to_filename(out_file)
return out_file


def average_images(images):
"""Average the voxel-wise signal intensity across a list of 3D image files to produce a 3D mean output image."""
from nilearn.image import mean_img

average_img = mean_img([nb.load(img) for img in images])
output_average_image = fname_presuffix(
images[0], use_ext=False, suffix="_mean.nii.gz"
)
average_img.to_filename(output_average_image)
return output_average_image


def quick_load_images(image_list, dtype=np.float32):
"""Iteratively loads 3D dwi volume files from a list of filepaths directly into a 4d array to use for signal
prediction. A helper function for EMC."""
example_img = nb.load(image_list[0])
num_images = len(image_list)
output_matrix = np.zeros(tuple(example_img.shape) + (num_images,), dtype=dtype)
for image_num, image_path in enumerate(image_list):
output_matrix[..., image_num] = nb.load(image_path).get_fdata(dtype=dtype)
return output_matrix


def match_transforms(dwi_files, transforms, b0_indices):
"""Arranges the order of a list of affine transforms to correspond with that of each individual dwi volume file,
accounting for the indices of B0s. A helper function for EMC."""
original_b0_indices = np.array(b0_indices)
num_dwis = len(dwi_files)
num_transforms = len(transforms)

if num_dwis == num_transforms:
return transforms

# Do sanity checks
if not len(transforms) == len(b0_indices):
raise Exception("number of transforms does not match number of b0 images")

# Create a list of which emc affines go with each of the split images
nearest_affines = []
for index in range(num_dwis):
nearest_b0_num = np.argmin(np.abs(index - original_b0_indices))
this_transform = transforms[nearest_b0_num]
nearest_affines.append(this_transform)

return nearest_affines


def save_4d_to_3d(in_file):
"""Loads a 4D input file and splits it in the 4th dimension to produce a list of 3D output files."""
files_3d = nb.four_to_three(nb.load(in_file))
out_files = []
for i, file_3d in enumerate(files_3d):
out_file = fname_presuffix(in_file, suffix="_tmp_{}".format(i))
file_3d.to_filename(out_file)
out_files.append(out_file)
del files_3d
return out_files


def prune_b0s_from_dwis(in_files, b0_ixs):
"""Removes B0 volume files from a complete list of dwi volume files."""
if in_files[0].endswith("_warped.nii.gz"):
out_files = [
i
for j, i in enumerate(
sorted(
in_files, key=lambda x: int(x.split("_")[-2].split(".nii.gz")[0])
)
)
if j not in b0_ixs
]
else:
out_files = [
i
for j, i in enumerate(
sorted(
in_files, key=lambda x: int(x.split("_")[-1].split(".nii.gz")[0])
)
)
if j not in b0_ixs
]
return out_files


def save_3d_to_4d(in_files):
"""Loads a list of 3D input files and concatenates it to produce a 4D output file."""
img_4d = nb.funcs.concat_images([nb.load(img_3d) for img_3d in in_files])
out_file = fname_presuffix(in_files[0], suffix="_merged")
img_4d.to_filename(out_file)
del img_4d
return out_file


def get_params(A):
"""This is a copy of spm's spm_imatrix where
we already know the rotations and translations matrix,
shears and zooms (as outputs from fsl FLIRT/avscale)
Let A = the 4x4 rotation and translation matrix
R = [ c5*c6, c5*s6, s5]
[-s4*s5*c6-c4*s6, -s4*s5*s6+c4*c6, s4*c5]
[-c4*s5*c6+s4*s6, -c4*s5*s6-s4*c6, c4*c5]
"""

def rang(b):
a = min(max(b, -1), 1)
return a

Ry = np.arcsin(A[0, 2])
# Rx = np.arcsin(A[1, 2] / np.cos(Ry))
# Rz = np.arccos(A[0, 1] / np.sin(Ry))

if (abs(Ry) - np.pi / 2) ** 2 < 1e-9:
Rx = 0
Rz = np.arctan2(-rang(A[1, 0]), rang(-A[2, 0] / A[0, 2]))
else:
c = np.cos(Ry)
Rx = np.arctan2(rang(A[1, 2] / c), rang(A[2, 2] / c))
Rz = np.arctan2(rang(A[0, 1] / c), rang(A[0, 0] / c))

rotations = [Rx, Ry, Rz]
translations = [A[0, 3], A[1, 3], A[2, 3]]

return rotations, translations

0 comments on commit 5f5ebb1

Please sign in to comment.