diff --git a/sdcflows/transform.py b/sdcflows/transform.py index 2c132258f5..b6e15f5178 100644 --- a/sdcflows/transform.py +++ b/sdcflows/transform.py @@ -65,7 +65,8 @@ def fit(self, target_reference, affine=None, approx=True): will be dropped. affine : :obj:`numpy.ndarray` Transform that maps coordinates on the target_reference on to the - fieldmap reference. + fieldmap reference (that is, the affine through which the filedmap + be resampled in register with the target_reference). approx : :obj:`bool` If ``True``, do not reconstruct the B-Spline field directly on the target (which will likely not be aligned with the fieldmap's grid), but rather use @@ -82,6 +83,7 @@ def fit(self, target_reference, affine=None, approx=True): if isinstance(target_reference, (str, bytes, Path)): target_reference = nb.load(target_reference) + approx = approx if affine is not None else False # Approximate iff affine is defined affine = affine if affine is not None else np.eye(4) target_affine = target_reference.affine.copy() @@ -102,15 +104,19 @@ def fit(self, target_reference, affine=None, approx=True): weights = [] coeffs = [] + _tmp_reference = None + hdr = target_reference.header.copy() if approx: - # Keep a copy - # target_reference_bak = target_reference + from sdcflows.utils.tools import deoblique_and_zooms + _tmp_reference = target_reference # Generate a sampling reference on the fieldmap's space that fully covers # the target_reference's grid. - # target_reference = ... - raise NotImplementedError + target_reference = deoblique_and_zooms( + listify(self.coeffs)[-1], + target_reference, + ) # Generate tensor-product B-Spline weights for level in listify(self.coeffs): @@ -125,12 +131,13 @@ def fit(self, target_reference, affine=None, approx=True): ) if approx: + from nitransforms.linear import Affine + # Interpolate fmap given on target_reference in target_reference_back # voxel locations (overwrite fmap) - raise NotImplementedError + fmap = Affine(reference=_tmp_reference).apply(fmap) # Cache - hdr = target_reference.header.copy() hdr.set_intent("estimate", name="fieldmap Hz") hdr.set_data_dtype("float32") hdr["cal_max"] = max((abs(fmap.min()), fmap.max())) @@ -140,7 +147,7 @@ def fit(self, target_reference, affine=None, approx=True): def apply( self, - data, + moving, pe_dir, ro_time, xfms=None, @@ -149,13 +156,16 @@ def apply( cval=0.0, prefilter=True, output_dtype=None, + num_threads=None, ): """ Apply a transformation to an image, resampling on the reference spatial object. + Handles parallelization to resample 4D images. + Parameters ---------- - data : `spatialimage` + moving : `spatialimage` The image object containing the data to be resampled in reference space xfms : `None` or :obj:`list` @@ -187,12 +197,13 @@ def apply( from sdcflows.utils.tools import ensure_positive_cosines # Ensure the fmap has been computed - if isinstance(spatialimage, (str, bytes, Path)): - spatialimage = nb.load(spatialimage) + if isinstance(moving, (str, bytes, Path)): + moving = nb.load(moving) - spatialimage, axcodes = ensure_positive_cosines(spatialimage) + # TODO: not sure this is necessary - instead check it matches self.mapped. + moving, axcodes = ensure_positive_cosines(moving) - self.fit(spatialimage) + self.fit(moving) fmap = self.mapped.get_fdata().copy() # Reverse mapped if reversed blips @@ -203,19 +214,19 @@ def apply( pe_axis = "ijk".index(pe_dir[0]) # Map voxel coordinates applying the VSM - if self.xfm is None: - ijk_axis = tuple([np.arange(s) for s in fmap.shape]) - voxcoords = np.array( - np.meshgrid(*ijk_axis, indexing="ij"), dtype="float32" - ).reshape(3, -1) - else: + ijk_axis = tuple([np.arange(s) for s in fmap.shape]) + voxcoords = np.array( + np.meshgrid(*ijk_axis, indexing="ij"), dtype="float32" + ).reshape(3, -1) + + if xfms is None: # Map coordinates from reference to time-step - self.xfm.reference = spatialimage - hmc_xyz = self.xfm.map(self.xfm.reference.ndcoords.T) + xfms.reference = moving + hmc_xyz = xfms.map(xfms.reference.ndcoords.T) # Convert from RAS to voxel coordinates voxcoords = ( - np.linalg.inv(self.xfm.reference.affine) - @ _as_homogeneous(np.vstack(hmc_xyz), dim=self.xfm.reference.ndim).T + np.linalg.inv(xfms.reference.affine) + @ _as_homogeneous(np.vstack(hmc_xyz), dim=xfms.reference.ndim).T )[:3, ...] # fmap * ro_time is the voxel-shift map (VSM) @@ -224,7 +235,7 @@ def apply( voxcoords[pe_axis, ...] += fmap.reshape(-1) * ro_time # Prepare data - data = np.squeeze(np.asanyarray(spatialimage.dataobj)) + data = np.squeeze(np.asanyarray(moving.dataobj)) output_dtype = output_dtype or data.dtype # Resample @@ -236,10 +247,10 @@ def apply( mode=mode, cval=cval, prefilter=prefilter, - ).reshape(spatialimage.shape) + ).reshape(moving.shape) - moved = spatialimage.__class__( - resampled, spatialimage.affine, spatialimage.header + moved = moving.__class__( + resampled, moving.affine, moving.header ) moved.header.set_data_dtype(output_dtype) return reorient_image(moved, axcodes)