Skip to content

Commit

Permalink
rename nufft-only-parameters and make oversampling a single float
Browse files Browse the repository at this point in the history
  • Loading branch information
fzimmermann89 authored Mar 7, 2024
1 parent ce4ba22 commit ece8e56
Showing 1 changed file with 16 additions and 23 deletions.
39 changes: 16 additions & 23 deletions src/mrpro/operators/_FourierOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def __init__(
recon_shape: SpatialDimension[int],
encoding_shape: SpatialDimension[int],
traj: KTrajectory,
oversampling: float | SpatialDimension[float] = 2.0, # only used for nuFFT
numpoints: int = 6,
kbwidth: float = 2.34,
nufft_oversampling: float = 2.0,
nufft_numpoints: int = 6,
nufft_kbwidth: float = 2.34,
) -> None:
"""Fourier Operator class.
Expand All @@ -48,19 +48,15 @@ def __init__(
dimension of the encoded k-space
traj
the k-space trajectories where the frequencies are sampled
oversampling
oversampling for (potential) nuFFT directions
numpoints
number of neighbors for interpolation for nuFFTs
kbwidth
size of the Kaiser-Bessel kernel for the nuFFT
nufft_oversampling
oversampling used for interpolation in non-uniform FFTs
nufft_numpoints
number of neighbors for interpolation in non-uniform FFTs
nufft_kbwidth
size of the Kaiser-Bessel kernel interpolation in non-uniform FFTs
"""
super().__init__()

# convert oversampling to SpatialDimension if float
if isinstance(oversampling, float):
oversampling = SpatialDimension(oversampling, oversampling, oversampling)

def get_spatial_dims(spatial_dims: SpatialDimension, dims: Sequence[int]):
return [
s
Expand Down Expand Up @@ -103,10 +99,7 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]):
)

self._nufft_im_size = get_spatial_dims(recon_shape, self._nufft_dims)
grid_size = [
int(s * os)
for s, os in zip(self._nufft_im_size, get_spatial_dims(oversampling, self._nufft_dims), strict=True)
]
grid_size = [int(size * nufft_oversampling) for size in self._nufft_im_size]
omega = [
k * 2 * torch.pi / ks
for k, ks in zip(
Expand All @@ -116,21 +109,21 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]):
)
]

# Broadcast shapes (not always needed but also does not hurt)
# Broadcast shapes not always needed but also does not hurt
omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega]
self._omega = torch.stack(omega, dim=-4) # use the 'coil' dim for the direction
self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction

self._fwd_nufft_op = KbNufft(
im_size=self._nufft_im_size,
grid_size=grid_size,
numpoints=numpoints,
kbwidth=kbwidth,
numpoints=nufft_numpoints,
kbwidth=nufft_kbwidth,
)
self._adj_nufft_op = KbNufftAdjoint(
im_size=self._nufft_im_size,
grid_size=grid_size,
numpoints=numpoints,
kbwidth=kbwidth,
numpoints=nufft_numpoints,
kbwidth=nufft_kbwidth,
)

self._kshape = traj.broadcasted_shape
Expand Down

0 comments on commit ece8e56

Please sign in to comment.