Skip to content

Commit

Permalink
fix: interface internal streamlining toward getting tests to pass
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Jun 30, 2023
1 parent 851df88 commit 182f3e5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
12 changes: 9 additions & 3 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,15 +383,21 @@ def _run_interface(self, runtime):

# HMC matrices are only necessary when reslicing the data (i.e., apply())
# Check the length of in_xfms matches that of in_data
self._results["out_corrected"] = unwarp.apply(
self._results["out_corrected"] = fname_presuffix(
self.inputs.in_data,
suffix="_sdc",
newpath=runtime.cwd,
)

unwarp.apply(
self.inputs.in_data,
self.inputs.pe_dir,
self.inputs.ro_time,
xfms=self.inputs.in_xfms,
xfms=self.inputs.in_xfms if isdefined(self.inputs.in_xfms) else None,
num_threads=(
None if not isdefined(self.inputs.num_threads) else self.inputs.num_threads
),
)
).to_filename(self._results["out_corrected"])
return runtime


Expand Down
7 changes: 4 additions & 3 deletions sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def apply(
cval: float = 0.0,
prefilter: bool = True,
output_dtype: Union[str, np.dtype] = None,
num_threads: int = os.cpu_count(),
num_threads: int = None,
allow_negative: bool = False,
):
"""
Expand Down Expand Up @@ -329,12 +329,13 @@ def apply(

# Prepare data
data = np.squeeze(np.asanyarray(moving.dataobj))
ndim = min(data.ndim, 3)
output_dtype = output_dtype or moving.header.get_data_dtype()

# Reference image's voxel coordinates (in voxel units)
voxcoords = nt.linear.Affine(
reference=moving
).reference.ndindex.reshape((data.ndim - 1, *data.shape[:-1])).astype("float32")
).reference.ndindex.reshape((ndim, *data.shape[:ndim])).astype("float32")

# The VSM is just the displacements field given in index coordinates
# voxcoords is the deformation field, i.e., the target position of each voxel
Expand All @@ -358,7 +359,7 @@ def apply(
mode=mode,
cval=cval,
prefilter=prefilter,
max_concurrent=num_threads,
max_concurrent=num_threads or min(os.cpu_count(), 12),
))

if not allow_negative:
Expand Down

0 comments on commit 182f3e5

Please sign in to comment.