diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index e8e1cb7598..d8415ad5bd 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -276,9 +276,7 @@ def _run_interface(self, runtime): class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec): - in_data = InputMultiObject( - File(exist=True, mandatory=True, desc="input EPI data to be corrected") - ) + in_data = File(exist=True, mandatory=True, desc="input EPI data to be corrected") in_coeff = InputMultiObject( File(exists=True), mandatory=True, @@ -288,23 +286,17 @@ class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec): exists=True, desc="the transform by which the fieldmap can be resampled on the target EPI's grid.", ) - in_xfms = InputMultiObject( - File(exists=True), desc="list of head-motion correction matrices" - ) - ro_time = InputMultiObject( - traits.Float(mandatory=True, desc="EPI readout time (s).") - ) - pe_dir = InputMultiObject( - traits.Enum( - "i", - "i-", - "j", - "j-", - "k", - "k-", - mandatory=True, - desc="the phase-encoding direction corresponding to in_data", - ) + in_xfms = File(exists=True, desc="list of head-motion correction matrices") + ro_time = traits.Float(mandatory=True, desc="EPI readout time (s).") + pe_dir = traits.Enum( + "i", + "i-", + "j", + "j-", + "k", + "k-", + mandatory=True, + desc="the phase-encoding direction corresponding to in_data", ) num_threads = traits.Int(nohash=True, desc="number of threads") approx = traits.Bool( @@ -361,18 +353,6 @@ class ApplyCoeffsField(SimpleInterface): def _run_interface(self, runtime): from sdcflows.transform import B0FieldTransform - n = len(self.inputs.in_data) - - ro_time = self.inputs.ro_time - if len(ro_time) == 1: - ro_time *= n - - pe_dir = self.inputs.pe_dir - if len(pe_dir) == 1: - pe_dir *= n - - unwarp = None - # Pre-cached interpolator object unwarp = B0FieldTransform(coeffs=[nb.load(cname) for cname in self.inputs.in_coeff]) @@ -390,12 +370,12 @@ def _run_interface(self, runtime): # self._results["out_field"] = out_field # HMC matrices are only necessary when reslicing the data (i.e., apply()) - hmc_mats = None + # Check the length of in_xfms matches that of in_data self._results["out_corrected"] = unwarp.apply( self.inputs.in_data, - pe_dir, - ro_time, - xfms=hmc_mats, + self.inputs.pe_dir, + self.input.ro_time, + xfms=self.inputs.in_xfms, # num_threads=( # None if not isdefined(self.inputs.num_threads) else self.inputs.num_threads # ),