diff --git a/src/mrpro/operators/_FourierOp.py b/src/mrpro/operators/_FourierOp.py index de3a2f73e..05e72a9a3 100644 --- a/src/mrpro/operators/_FourierOp.py +++ b/src/mrpro/operators/_FourierOp.py @@ -34,9 +34,9 @@ def __init__( recon_shape: SpatialDimension[int], encoding_shape: SpatialDimension[int], traj: KTrajectory, - oversampling: float | SpatialDimension[float] = 2.0, # only used for nuFFT - numpoints: int = 6, - kbwidth: float = 2.34, + nufft_oversampling: float = 2.0, + nufft_numpoints: int = 6, + nufft_kbwidth: float = 2.34, ) -> None: """Fourier Operator class. @@ -48,19 +48,15 @@ def __init__( dimension of the encoded k-space traj the k-space trajectories where the frequencies are sampled - oversampling - oversampling for (potential) nuFFT directions - numpoints - number of neighbors for interpolation for nuFFTs - kbwidth - size of the Kaiser-Bessel kernel for the nuFFT + nufft_oversampling + oversampling used for interpolation in non-uniform FFTs + nufft_numpoints + number of neighbors for interpolation in non-uniform FFTs + nufft_kbwidth + size of the Kaiser-Bessel kernel interpolation in non-uniform FFTs """ super().__init__() - # convert oversampling to SpatialDimension if float - if isinstance(oversampling, float): - oversampling = SpatialDimension(oversampling, oversampling, oversampling) - def get_spatial_dims(spatial_dims: SpatialDimension, dims: Sequence[int]): return [ s @@ -103,10 +99,7 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): ) self._nufft_im_size = get_spatial_dims(recon_shape, self._nufft_dims) - grid_size = [ - int(s * os) - for s, os in zip(self._nufft_im_size, get_spatial_dims(oversampling, self._nufft_dims), strict=True) - ] + grid_size = [int(size * nufft_oversampling) for size in self._nufft_im_size] omega = [ k * 2 * torch.pi / ks for k, ks in zip( @@ -116,21 +109,21 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]): ) ] - # Broadcast shapes (not always needed but also does not hurt) + # Broadcast shapes not always needed but also does not hurt omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega] - self._omega = torch.stack(omega, dim=-4) # use the 'coil' dim for the direction + self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction self._fwd_nufft_op = KbNufft( im_size=self._nufft_im_size, grid_size=grid_size, - numpoints=numpoints, - kbwidth=kbwidth, + numpoints=nufft_numpoints, + kbwidth=nufft_kbwidth, ) self._adj_nufft_op = KbNufftAdjoint( im_size=self._nufft_im_size, grid_size=grid_size, - numpoints=numpoints, - kbwidth=kbwidth, + numpoints=nufft_numpoints, + kbwidth=nufft_kbwidth, ) self._kshape = traj.broadcasted_shape