Skip to content

Commit

Permalink
209 loadnifti support multi-files (Project-MONAI#213)
Browse files Browse the repository at this point in the history
* [DLMED] add multi-files support for LoadNifti

* minor fixes load_nifti

* update docstrings loadnifti

Co-authored-by: Wenqi Li <[email protected]>
  • Loading branch information
Nic-Ma and wyli authored Mar 26, 2020
1 parent 434feb7 commit 5ae87f8
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 40 deletions.
10 changes: 8 additions & 2 deletions monai/transforms/composables.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,15 @@ def __call__(self, data):
@alias('LoadNiftiD', 'LoadNiftiDict')
class LoadNiftid(MapTransform):
"""
dictionary-based wrapper of LoadNifti, must load image and metadata together.
Dictionary-based wrapper of LoadNifti, must load image and metadata
together. If loading a list of files in one key, stack them together and
add a new dimension as the first dimension, and use the meta data of the
first image to represent the stacked result. Note that the affine transform
of all the stacked images should be same.
"""

def __init__(self, keys, as_closest_canonical=False, dtype=np.float32, meta_key_format='{}.{}', overwriting_keys=False):
def __init__(self, keys, as_closest_canonical=False, dtype=np.float32,
meta_key_format='{}.{}', overwriting_keys=False):
"""
Args:
keys (hashable items): keys of the corresponding items to be transformed.
Expand Down Expand Up @@ -794,6 +799,7 @@ class RandRotated(Randomizable, MapTransform):
cval (scalar): Value to fill outside boundary. Default: 0.
prefiter (bool): Apply spline_filter before interpolation. Default: True.
"""

def __init__(self, keys, degrees, prob=0.1, spatial_axes=(0, 1), reshape=True, order=1,
mode='constant', cval=0, prefilter=True):
MapTransform.__init__(self, keys)
Expand Down
60 changes: 37 additions & 23 deletions monai/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,10 @@ def __call__(self, data_array, original_affine=None, original_axcodes=None):
@export
class LoadNifti:
"""
Load Nifti format file from provided path.
Load Nifti format file or files from provided path. If loading a list of
files, stack them together and add a new dimension as first dimension, and
use the meta data of the first image to represent the stacked result. Note
that the affine transform of all the images should be same if ``image_only=False``.
"""

def __init__(self, as_closest_canonical=False, image_only=False, dtype=np.float32):
Expand All @@ -162,34 +165,45 @@ def __init__(self, as_closest_canonical=False, image_only=False, dtype=np.float3
def __call__(self, filename):
"""
Args:
filename (str or file): path to file or file-like object.
filename (str, list, tuple, file): path file or file-like object or a list of files.
"""
img = nib.load(filename)
img = correct_nifti_header_if_necessary(img)
filename = ensure_tuple(filename)
img_array = list()
compatible_meta = dict()
for name in filename:
img = nib.load(name)
img = correct_nifti_header_if_necessary(img)
header = dict(img.header)
header['filename_or_obj'] = name
header['original_affine'] = img.affine
header['affine'] = img.affine
header['as_closest_canonical'] = self.as_closest_canonical

header = dict(img.header)
header['filename_or_obj'] = filename
header['original_affine'] = img.affine
header['affine'] = img.affine
header['as_closest_canonical'] = self.as_closest_canonical
if self.as_closest_canonical:
img = nib.as_closest_canonical(img)
header['affine'] = img.affine

if self.as_closest_canonical:
img = nib.as_closest_canonical(img)
header['affine'] = img.affine
img_array.append(np.array(img.get_fdata(dtype=self.dtype)))
img.uncache()

if self.image_only:
continue

data = np.array(img.get_fdata(dtype=self.dtype))
img.uncache()
if not compatible_meta:
for meta_key in header:
meta_datum = header[meta_key]
if type(meta_datum).__name__ == 'ndarray' \
and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None:
continue
compatible_meta[meta_key] = meta_datum
else:
assert np.allclose(header['affine'], compatible_meta['affine']), \
'affine data of all images should be same.'

img_array = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
if self.image_only:
return data
compatible_meta = dict()
for meta_key in header:
meta_datum = header[meta_key]
if type(meta_datum).__name__ == 'ndarray' \
and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None:
continue
compatible_meta[meta_key] = meta_datum
return data, compatible_meta
return img_array
return img_array, compatible_meta


@export
Expand Down
38 changes: 29 additions & 9 deletions tests/test_load_nifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,40 +11,60 @@

import unittest
import os
import shutil
import numpy as np
import tempfile
import nibabel as nib
from parameterized import parameterized
from monai.transforms.transforms import LoadNifti

TEST_CASE_IMAGE_ONLY = [
TEST_CASE_1 = [
{
'as_closest_canonical': False,
'image_only': True
},
['test_image.nii.gz'],
(128, 128, 128)
]

TEST_CASE_IMAGE_METADATA = [
TEST_CASE_2 = [
{
'as_closest_canonical': False,
'image_only': False
},
['test_image.nii.gz'],
(128, 128, 128)
]

TEST_CASE_3 = [
{
'as_closest_canonical': False,
'image_only': True
},
['test_image1.nii.gz', 'test_image2.nii.gz', 'test_image3.nii.gz'],
(3, 128, 128, 128)
]

TEST_CASE_4 = [
{
'as_closest_canonical': False,
'image_only': False
},
['test_image1.nii.gz', 'test_image2.nii.gz', 'test_image3.nii.gz'],
(3, 128, 128, 128)
]


class TestLoadNifti(unittest.TestCase):

@parameterized.expand([TEST_CASE_IMAGE_ONLY, TEST_CASE_IMAGE_METADATA])
def test_shape(self, input_param, expected_shape):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_shape(self, input_param, filenames, expected_shape):
test_image = np.random.randint(0, 2, size=[128, 128, 128])
tempdir = tempfile.mkdtemp()
nib.save(nib.Nifti1Image(test_image, np.eye(4)), os.path.join(tempdir, 'test_image.nii.gz'))
test_data = os.path.join(tempdir, 'test_image.nii.gz')
result = LoadNifti(**input_param)(test_data)
shutil.rmtree(tempdir)
with tempfile.TemporaryDirectory() as tempdir:
for i, name in enumerate(filenames):
filenames[i] = os.path.join(tempdir, name)
nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])
result = LoadNifti(**input_param)(filenames)
if isinstance(result, tuple):
result = result[0]
self.assertTupleEqual(result.shape, expected_shape)
Expand Down
11 changes: 5 additions & 6 deletions tests/test_load_niftid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import unittest
import os
import shutil
import numpy as np
import tempfile
import nibabel as nib
Expand All @@ -36,11 +35,11 @@ def test_shape(self, input_param, expected_shape):
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4))
tempdir = tempfile.mkdtemp()
test_data = dict()
for key in KEYS:
nib.save(test_image, os.path.join(tempdir, key + '.nii.gz'))
test_data.update({key: os.path.join(tempdir, key + '.nii.gz')})
result = LoadNiftid(**input_param)(test_data)
shutil.rmtree(tempdir)
with tempfile.TemporaryDirectory() as tempdir:
for key in KEYS:
nib.save(test_image, os.path.join(tempdir, key + '.nii.gz'))
test_data.update({key: os.path.join(tempdir, key + '.nii.gz')})
result = LoadNiftid(**input_param)(test_data)
for key in KEYS:
self.assertTupleEqual(result[key].shape, expected_shape)

Expand Down

0 comments on commit 5ae87f8

Please sign in to comment.