diff --git a/docs/conf.py b/docs/conf.py index 65645bbd..e3c3419b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -55,6 +55,7 @@ "pandas", "seaborn", "skimage", + "sklearn", "svgutils", "tqdm", "transforms3d", diff --git a/setup.cfg b/setup.cfg index 327ef964..f093ade8 100755 --- a/setup.cfg +++ b/setup.cfg @@ -30,9 +30,10 @@ install_requires = joblib nipype>= 1.5.1, < 2.0 nitransforms>=21.0.0 + numpy>=1.17.3 nest-asyncio>=1.5.1 scikit-image>=0.14.2 - scikit-learn>=1.0.1 + scipy>=1.8.0 test_requires = codecov coverage diff --git a/src/eddymotion/data/pet.py b/src/eddymotion/data/pet.py new file mode 100644 index 00000000..415b477c --- /dev/null +++ b/src/eddymotion/data/pet.py @@ -0,0 +1,179 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright 2022 The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def _data_repr(value):
    """
    Summarize *value* as ``<shape (dtype)>`` for the ``attrs``-generated repr.

    Values lacking ``shape``/``dtype`` (e.g., a plain Python float) fall back
    to their regular :func:`repr` instead of raising ``AttributeError``.
    """
    if value is None:
        return "None"

    shape = getattr(value, "shape", None)
    dtype = getattr(value, "dtype", None)
    if shape is None or dtype is None:
        return repr(value)
    return f"<{'x'.join(str(v) for v in shape)} ({dtype})>"


@attr.s(slots=True)
class PET:
    """Data representation structure for PET data."""

    dataobj = attr.ib(default=None, repr=_data_repr)
    """A numpy ndarray object for the data array, with frames along the last axis."""
    affine = attr.ib(default=None, repr=_data_repr)
    """Best affine for RAS-to-voxel conversion of coordinates (NIfTI header)."""
    brainmask = attr.ib(default=None, repr=_data_repr)
    """A boolean ndarray object containing a corresponding brainmask."""
    frame_time = attr.ib(default=None, repr=_data_repr)
    """A 1D numpy array with the midpoint timing of each sample."""
    total_duration = attr.ib(default=None, repr=_data_repr)
    """A float number representing the total duration of acquisition."""

    em_affines = attr.ib(default=None)
    """
    List of :obj:`nitransforms.linear.Affine` objects that bring
    PET timepoints into alignment.
    """
    _filepath = attr.ib(
        factory=lambda: Path(mkdtemp()) / "em_cache.h5",
        repr=False,
    )
    """A path to an HDF5 file to store the whole dataset."""

    def __len__(self):
        """Obtain the number of frames (size of the last axis of ``dataobj``)."""
        return self.dataobj.shape[-1]

    def set_transform(self, index, affine, order=3):
        """
        Set the affine of one frame, and resample that frame in ``dataobj``.

        Parameters
        ----------
        index : :obj:`int`
            Position of the frame along the last axis of ``dataobj``.
        affine : array-like
            Matrix handed over to :obj:`nitransforms.linear.Affine`.
        order : :obj:`int`
            Interpolation order passed on to ``Affine.apply``.

        """
        reference = namedtuple("ImageGrid", ("shape", "affine"))(
            shape=self.dataobj.shape[:3], affine=self.affine
        )
        xform = Affine(matrix=affine, reference=reference)

        # The on-disk cache keeps the data as it was when first written, so
        # every call resamples from that copy rather than from ``dataobj``.
        if not Path(self._filepath).exists():
            self.to_filename(self._filepath)

        # read original PET
        with h5py.File(self._filepath, "r") as in_file:
            root = in_file["/0"]
            dframe = np.asanyarray(root["dataobj"][..., index])

        dmoving = nb.Nifti1Image(dframe, self.affine, None)

        # resample and update the frame at index
        self.dataobj[..., index] = np.asanyarray(
            xform.apply(dmoving, order=order).dataobj,
            dtype=self.dataobj.dtype,
        )

        # update transform
        if self.em_affines is None:
            self.em_affines = [None] * len(self)

        self.em_affines[index] = xform

    def to_filename(self, filename, compression=None, compression_opts=None):
        """
        Write an HDF5 file to disk.

        A ``.h5`` extension is appended to *filename* when missing. Every
        public attribute that is not ``None`` becomes a dataset under the
        ``/0`` group; underscore-prefixed attributes are skipped.
        """
        filename = Path(filename)
        if not filename.name.endswith(".h5"):
            filename = filename.parent / f"{filename.name}.h5"

        with h5py.File(filename, "w") as out_file:
            out_file.attrs["Format"] = "EMC/PET"
            out_file.attrs["Version"] = np.uint16(1)
            root = out_file.create_group("/0")
            root.attrs["Type"] = "pet"
            for f in attr.fields(self.__class__):
                if f.name.startswith("_"):
                    continue

                value = getattr(self, f.name)
                if value is not None:
                    # NOTE(review): after set_transform(), ``em_affines`` is a
                    # list of Affine objects (possibly with None entries);
                    # confirm h5py can serialize it before relying on this.
                    root.create_dataset(
                        f.name,
                        data=value,
                        compression=compression,
                        compression_opts=compression_opts,
                    )

    def to_nifti(self, filename, insert_b0=False):
        """
        Write a NIfTI 1.0 file to disk.

        ``insert_b0`` is accepted for signature compatibility only; it is
        not used by this method.
        """
        nii = nb.Nifti1Image(self.dataobj, self.affine, None)
        nii.header.set_xyzt_units("mm")
        nii.to_filename(filename)

    @classmethod
    def from_filename(cls, filename):
        """Read an HDF5 file from disk and build a new object from its ``/0`` group."""
        with h5py.File(filename, "r") as in_file:
            root = in_file["/0"]
            data = {
                k: np.asanyarray(v) for k, v in root.items() if not k.startswith("_")
            }
        return cls(**data)


def load(
    filename,
    brainmask_file=None,
    frame_time=None,
    frame_duration=None,
):
    """
    Load PET data.

    Parameters
    ----------
    filename : :obj:`str` or :obj:`os.PathLike`
        Either an HDF5 file written by :meth:`PET.to_filename` (``.h5``
        extension) or an image readable by nibabel.
    brainmask_file : :obj:`str` or :obj:`os.PathLike`, optional
        Image containing a brain mask.
    frame_time : array-like
        Start time of each frame; mandatory unless *filename* is an HDF5 file.
        Times are re-referenced so the first frame starts at zero.
    frame_duration : array-like, optional
        Duration of each frame. When ``None``, durations are the differences
        between consecutive start times. If one entry short of the number of
        frames, the last duration is repeated.

    Returns
    -------
    :obj:`PET`
        The dataset, with ``frame_time`` holding frame midpoints and
        ``total_duration`` the end of the last frame.

    Raises
    ------
    RuntimeError
        If *frame_time* is not given for a non-HDF5 input.
    ValueError
        If the timing information does not match the number of frames.

    """
    filename = Path(filename)
    if filename.name.endswith(".h5"):
        return PET.from_filename(filename)

    # Fail before reading the (potentially large) image into memory
    if frame_time is None:
        raise RuntimeError(
            "Start time of frames is mandatory (see https://bids-specification.readthedocs.io/"
            "en/stable/glossary.html#objects.metadata.FrameTimesStart)"
        )

    img = nb.load(filename)
    retval = PET(
        dataobj=img.get_fdata(dtype="float32"),
        affine=img.affine,
    )
    n_frames = retval.dataobj.shape[-1]

    frame_time = np.asarray(frame_time, dtype="float32")
    if frame_time.size == 0:
        raise ValueError("frame_time must contain at least one start time")
    frame_time = frame_time - frame_time[0]

    if frame_duration is None:
        frame_duration = np.diff(frame_time)
    frame_duration = np.asarray(frame_duration, dtype="float32")

    # The duration of the last frame cannot be derived from start times:
    # assume it equals the previous one.
    if 0 < len(frame_duration) == (n_frames - 1):
        frame_duration = np.append(frame_duration, frame_duration[-1])

    # Explicit check (rather than ``assert``) so it survives ``python -O``
    if len(frame_time) != n_frames or len(frame_duration) != n_frames:
        raise ValueError(
            f"Timing information ({len(frame_time)} start times, "
            f"{len(frame_duration)} durations) does not match the number "
            f"of frames ({n_frames})"
        )

    retval.total_duration = frame_time[-1] + frame_duration[-1]
    retval.frame_time = frame_time + 0.5 * frame_duration

    if brainmask_file:
        mask = nb.load(brainmask_file)
        retval.brainmask = np.asanyarray(mask.dataobj)

    return retval
not in align_kwargs and omp_nthreads is not None: align_kwargs["num_threads"] = omp_nthreads @@ -103,6 +93,28 @@ def fit( if model.lower() not in ("b0", "s0", "avg", "average", "mean") else "b0" ) + + # When downsampling these need to be set per-level + bmask_img = None + if dwdata.brainmask is not None: + _, bmask_img = mkstemp(suffix="_bmask.nii.gz") + nb.Nifti1Image( + dwdata.brainmask.astype("uint8"), dwdata.affine, None + ).to_filename(bmask_img) + kwargs["mask"] = dwdata.brainmask + + if hasattr(dwdata, "bzero") and dwdata.bzero is not None: + kwargs["S0"] = _advanced_clip(dwdata.bzero) + + if hasattr(dwdata, "gradients"): + kwargs["gtab"] = dwdata.gradients + + if hasattr(dwdata, "frame_time"): + kwargs["timepoints"] = dwdata.frame_time + + if hasattr(dwdata, "total_duration"): + kwargs["xlim"] = dwdata.total_duration + index_order = np.arange(len(dwdata)) np.random.shuffle(index_order) @@ -118,7 +130,6 @@ def fit( # Factory creates the appropriate model and pipes arguments dwmodel = ModelFactory.init( - gtab=dwdata.gradients, model=model, **kwargs, ) @@ -137,9 +148,10 @@ def fit( pbar.set_description_str(f"[{grad_str}], {n_jobs} jobs") if not single_model: # A true LOGO estimator + if hasattr(dwdata, "gradients"): + kwargs["gtab"] = data_train[1] # Factory creates the appropriate model and pipes arguments dwmodel = ModelFactory.init( - gtab=data_train[1], model=model, n_jobs=n_jobs, **kwargs, diff --git a/src/eddymotion/model/__init__.py b/src/eddymotion/model/__init__.py index 6fa40bd8..e2eb4565 100644 --- a/src/eddymotion/model/__init__.py +++ b/src/eddymotion/model/__init__.py @@ -27,6 +27,7 @@ DKIModel, DTIModel, TrivialB0Model, + PETModel, ) __all__ = ( @@ -35,4 +36,5 @@ "DKIModel", "DTIModel", "TrivialB0Model", + "PETModel", ) diff --git a/src/eddymotion/model/base.py b/src/eddymotion/model/base.py index 65b76dd6..59b7a863 100644 --- a/src/eddymotion/model/base.py +++ b/src/eddymotion/model/base.py @@ -41,7 +41,7 @@ class ModelFactory: """A factory 
class PETModel:
    """A PET imaging realignment model based on B-Spline approximation."""

    __slots__ = ("_t", "_x", "_xlim", "_order", "_coeff", "_mask", "_shape", "_n_ctrl")

    def __init__(self, timepoints=None, xlim=None, n_ctrl=None, mask=None, order=3, **kwargs):
        """
        Create the B-Spline interpolating matrix.

        Parameters
        ----------
        timepoints : :obj:`list`
            The timing (in sec) of each PET volume.
            E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330.,
            420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]``
        xlim : :obj:`float`
            Upper end of the time axis; timepoints are rescaled by it.
        n_ctrl : :obj:`int`
            Number of B-Spline control points. If `None`, then
            ``len(timepoints) // 4 + 1`` control points will be used.
            The less control points, the smoother is the model.
        mask : :obj:`numpy.ndarray`
            Boolean mask selecting the voxels to be fitted.
        order : :obj:`int`
            Order of the B-Spline basis.

        Raises
        ------
        TypeError
            If ``timepoints`` or ``xlim`` are missing.
        ValueError
            If the first/last timepoints fall outside of the open
            interval ``(0, xlim)``.

        """
        if timepoints is None:
            raise TypeError("timepoints must be provided in initialization")
        if xlim is None:
            raise TypeError("xlim must be provided in initialization")

        self._order = order
        self._mask = mask

        self._x = np.array(timepoints, dtype="float32")
        self._xlim = xlim

        if self._x[0] < 1e-2:
            raise ValueError("First frame midpoint should not be zero or negative")
        if self._x[-1] > (self._xlim - 1e-2):
            raise ValueError("Last frame midpoint should not be equal or greater than duration")

        # Calculate index coordinates in the B-Spline grid
        self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1

        # B-Spline knots: uniform on [0, n_ctrl], padded with ``order`` extra
        # knots on each side so the basis is complete over the whole interval
        # (for order=3 this is arange(-3, n_ctrl + 4)).
        self._t = np.arange(
            -self._order, float(self._n_ctrl) + self._order + 1, dtype="float32"
        )

    def fit(self, data, *args, **kwargs):
        """
        Fit the model.

        Parameters
        ----------
        data : :obj:`numpy.ndarray`
            4D array with time along the last axis.
        timepoints : array-like, optional (keyword)
            Timing of the frames in ``data``; defaults to the timepoints
            given at initialization.
        n_jobs : :obj:`int`, optional (keyword)
            Number of parallel jobs; a falsy value means one job.

        """
        from scipy.interpolate import BSpline
        from scipy.sparse.linalg import cg

        n_jobs = kwargs.pop("n_jobs", None) or 1

        # Do not rely on truthiness here: ``array or default`` raises for
        # arrays with more than one element.
        timepoints = kwargs.get("timepoints", None)
        if timepoints is None or np.size(timepoints) == 0:
            timepoints = self._x
        x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl

        self._shape = data.shape[:3]

        # Convert data into V (voxels) x T (timepoints)
        data = (
            data.reshape((-1, data.shape[-1]))
            if self._mask is None else data[self._mask]
        )

        # A.shape = (T, K - order - 1); T= n. timepoints, K= n. knots (with padding)
        A = BSpline.design_matrix(x, self._t, k=self._order)
        AT = A.T
        ATdotA = AT @ A

        # One single CPU - linear execution (full model)
        if n_jobs == 1:
            self._coeff = np.array([
                cg(ATdotA, AT @ v)[0] for v in data
            ])
            return

        # Parallelize process with joblib
        with Parallel(n_jobs=n_jobs) as executor:
            results = executor(
                delayed(cg)(ATdotA, AT @ v)
                for v in data
            )

        self._coeff = np.array([r[0] for r in results])

    def predict(self, timepoint, **kwargs):
        """
        Return the volume predicted by the fitted model at *timepoint*.

        Voxels outside of the mask (if one was given) are set to zero.
        """
        from scipy.interpolate import BSpline

        # Project sample timing into B-Spline coordinates
        x = np.atleast_1d((timepoint / self._xlim) * self._n_ctrl)
        A = BSpline.design_matrix(x, self._t, k=self._order)

        # A is 1 (num. timepoints) x C (num. coeff)
        # self._coeff is V (num. voxels) x C (num. coeff)
        predicted = np.squeeze(A @ self._coeff.T)

        if self._mask is None:
            return predicted.reshape(self._shape)

        retval = np.zeros(self._shape, dtype="float32")
        retval[self._mask] = predicted
        return retval