diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index 24d96dce13..7cd79e3baa 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -378,20 +378,30 @@ def _run_interface(self, runtime): ) # We can now write out the fieldmap - # unwarp.mapped.to_filename(out_field) - # self._results["out_field"] = out_field + self._results["out_field"] = fname_presuffix( + self.inputs.in_data, + suffix="_field", + newpath=runtime.cwd, + ) + unwarp.mapped.to_filename(self._results["out_field"]) # HMC matrices are only necessary when reslicing the data (i.e., apply()) # Check the length of in_xfms matches that of in_data - self._results["out_corrected"] = unwarp.apply( + self._results["out_corrected"] = fname_presuffix( + self.inputs.in_data, + suffix="_sdc", + newpath=runtime.cwd, + ) + + unwarp.apply( self.inputs.in_data, self.inputs.pe_dir, self.inputs.ro_time, - xfms=self.inputs.in_xfms, + xfms=self.inputs.in_xfms if isdefined(self.inputs.in_xfms) else None, num_threads=( None if not isdefined(self.inputs.num_threads) else self.inputs.num_threads ), - ) + ).to_filename(self._results["out_corrected"]) return runtime diff --git a/sdcflows/transform.py b/sdcflows/transform.py index 3dce1d3baf..9f527ff81a 100644 --- a/sdcflows/transform.py +++ b/sdcflows/transform.py @@ -261,7 +261,7 @@ def apply( cval: float = 0.0, prefilter: bool = True, output_dtype: Union[str, np.dtype] = None, - num_threads: int = os.cpu_count(), + num_threads: int = None, allow_negative: bool = False, ): """ @@ -329,12 +329,13 @@ def apply( # Prepare data data = np.squeeze(np.asanyarray(moving.dataobj)) + ndim = min(data.ndim, 3) output_dtype = output_dtype or moving.header.get_data_dtype() # Reference image's voxel coordinates (in voxel units) voxcoords = nt.linear.Affine( reference=moving - ).reference.ndindex.reshape((data.ndim - 1, *data.shape[:-1])).astype("float32") + ).reference.ndindex.reshape((ndim, *data.shape[:ndim])).astype("float32") # The VSM is just the displacements field given in index coordinates # voxcoords is the deformation field, i.e., the target position of each voxel @@ -358,7 +359,7 @@ def apply( mode=mode, cval=cval, prefilter=prefilter, - max_concurrent=num_threads, + max_concurrent=num_threads or min(os.cpu_count(), 12), )) if not allow_negative: