Skip to content

Commit

Permalink
Fftshift for FourierOp (#143)
Browse files Browse the repository at this point in the history
* FastFourierOp and PadOp added
  • Loading branch information
ckolbPTB authored Jan 26, 2024
1 parent 7f183a6 commit 93667b3
Show file tree
Hide file tree
Showing 14 changed files with 412 additions and 107 deletions.
8 changes: 4 additions & 4 deletions src/mrpro/algorithms/_remove_readout_os.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

from mrpro.data import KData
from mrpro.data import KTrajectory
from mrpro.utils.fft import image_to_kspace
from mrpro.utils.fft import kspace_to_image
from mrpro.operators import FastFourierOp


def remove_readout_os(kdata: KData) -> KData:
Expand Down Expand Up @@ -56,9 +55,10 @@ def crop_readout(input):
return input[..., start_cropped_readout:end_cropped_readout]

# Transform to image space, crop to reconstruction matrix size and transform back
dat = kspace_to_image(kdata.data, dim=(-1,))
FFOp = FastFourierOp(dim=(-1,))
dat = FFOp.adjoint(kdata.data)
dat = crop_readout(dat)
dat = image_to_kspace(dat, dim=(-1,))
dat = FFOp.forward(dat)

# Adapt trajectory
ks = [kdata.traj.kz, kdata.traj.ky, kdata.traj.kx]
Expand Down
90 changes: 90 additions & 0 deletions src/mrpro/operators/_FastFourierOp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Class for Fast Fourier Operator."""

# Copyright 2023 Physikalisch-Technische Bundesanstalt
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from mrpro.operators import LinearOperator
from mrpro.operators import ZeroPadOp


class FastFourierOp(LinearOperator):
"""Fast Fourier operator class."""

def __init__(
self,
dim: tuple[int, ...] = (-3, -2, -1),
recon_shape: tuple[int, ...] | None = None,
encoding_shape: tuple[int, ...] | None = None,
) -> None:
"""Fast Fourier Operator class.
Remark regarding the fftshift/ifftshift:
fftshift shifts the zero-frequency point to the center of the data, ifftshift undoes this operation.
The input to both forward and ajoint assumes that the zero-frequency is in the center of the data.
Torch.fft.fftn and torch.fft.ifftn expect the zero-frequency to be the first entry in the tensor.
Therefore for forward and ajoint first ifftshift needs to be applied, then fftn or ifftn and then ifftshift.
Parameters
----------
dim, optional
dim along which FFT and IFFT are applied, by default last three dimensions (-1, -2, -3)
encoding_shape, optional
shape of encoded data
recon_shape, optional
shape of reconstructed data
"""
super().__init__()
self._dim: tuple[int, ...] = dim
self._pad_op: ZeroPadOp
if encoding_shape is not None and recon_shape is not None:
self._pad_op = ZeroPadOp(dim=dim, orig_shape=recon_shape, padded_shape=encoding_shape)
else:
# No padding
self._pad_op = ZeroPadOp(dim=(), orig_shape=(), padded_shape=())

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""FFT from image space to k-space.
Parameters
----------
x
image data on Cartesian grid
Returns
-------
FFT of x
"""
return torch.fft.fftshift(
torch.fft.fftn(torch.fft.ifftshift(self._pad_op.forward(x), dim=self._dim), dim=self._dim, norm='ortho'),
dim=self._dim,
)

def adjoint(self, y: torch.Tensor) -> torch.Tensor:
"""IFFT from k-space to image space.
Parameters
----------
y
k-space data on Cartesian grid
Returns
-------
IFFT of y
"""
# FFT
return self._pad_op.adjoint(
torch.fft.fftshift(
torch.fft.ifftn(torch.fft.ifftshift(y, dim=self._dim), dim=self._dim, norm='ortho'), dim=self._dim
)
)
50 changes: 19 additions & 31 deletions src/mrpro/operators/_FourierOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
# limitations under the License.

import torch
import torch.nn.functional as F
from einops import rearrange
from einops import repeat
from torchkbnufft import KbNufft
from torchkbnufft import KbNufftAdjoint

from mrpro.data import KTrajectory
from mrpro.data import SpatialDimension
from mrpro.operators import FastFourierOp
from mrpro.operators import LinearOperator


Expand Down Expand Up @@ -85,38 +85,36 @@ def __init__(
nufft_dims = []
fft_dims = []
ignore_dims = []
fft_s = []
fft_recon_shape = []
fft_encoding_shape = []
omega = []
traj_shape = []
super().__init__()

# create information about image shape, k-data shape etc
# and identify which directions can be ignored, which ones require a nuFFT
# and for which ones a simple FFT suffices
for n, os, k, i in zip(
for rs, es, os, k, i in zip(
(recon_shape.z, recon_shape.y, recon_shape.x),
(encoding_shape.z, encoding_shape.y, encoding_shape.x),
(oversampling.z, oversampling.y, oversampling.x),
(traj.kz, traj.ky, traj.kx),
(-3, -2, -1),
):
nk_list = [traj.kz.shape[i], traj.ky.shape[i], traj.kx.shape[i]]
if n <= 1 and nk_list.count(1) == 3:
if rs <= 1 and nk_list.count(1) == 3:
# dimension with no Fourier transform
ignore_dims.append(i)

elif nk_list.count(1) == 2: # and is_uniform(k): #TODO: maybe is_uniform never needed?
# dimension with FFT
nk = torch.tensor(nk_list)
supp = torch.where(nk != 1, nk, 0)
s = supp.max().item()

# append dimension and output shape for oversampled FFT
fft_dims.append(i)
fft_s.append(s)
fft_recon_shape.append(int(rs))
fft_encoding_shape.append(int(es))
else:
# dimension with nuFFT
grid_size.append(int(os * n))
nufft_im_size.append(n)
grid_size.append(int(os * rs))
nufft_im_size.append(rs)
nufft_dims.append(i)

# TODO: can omega be created here already?
Expand Down Expand Up @@ -148,11 +146,16 @@ def __init__(
self._ignore_dims = tuple(ignore_dims)
self._nufft_dims = tuple(nufft_dims)
self._fft_dims = tuple(fft_dims)
self._fft_s = tuple(fft_s)
self._fft_recon_shape = tuple(fft_recon_shape)
self._fft_encoding_shape = tuple(fft_encoding_shape)
self._kshape = torch.broadcast_shapes(*traj_shape)
self._recon_shape = recon_shape
self._nufft_im_size = nufft_im_size

self._fast_fourier_op = FastFourierOp(
dim=self._fft_dims, recon_shape=self._fft_recon_shape, encoding_shape=self._fft_encoding_shape
)

@staticmethod
def get_target_pattern(fft_dims: tuple[int, ...], nufft_dims: tuple[int, ...], ignore_dims: tuple[int, ...]) -> str:
"""Pattern to reshape image/k-space data to be able to perform nuFFT.
Expand Down Expand Up @@ -202,7 +205,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
raise ValueError('image data shape missmatch')

if len(self._fft_dims) != 0:
x = torch.fft.fftn(x, s=self._fft_s, dim=self._fft_dims, norm='ortho')
x = self._fast_fourier_op.forward(x)

if len(self._nufft_dims) != 0:
init_pattern = 'other coils dim2 dim1 dim0'
Expand Down Expand Up @@ -253,22 +256,7 @@ def adjoint(self, y: torch.Tensor) -> torch.Tensor:

# apply IFFT
if len(self._fft_dims) != 0:
recon_shape = [self._recon_shape.z, self._recon_shape.y, self._recon_shape.x]
y = torch.fft.ifftn(y, s=self._fft_s, dim=self._fft_dims, norm='ortho')

# construct the paddings based on the FFT-directions
# TODO: can this be written more nicely?
diff_dim = torch.tensor(recon_shape) - torch.tensor(y.shape[2:])
npad_tuple = tuple(
[
int(diff_dim[i // 2].item()) if (i % 2 == 0 and dim in self._fft_dims) else 0
for (i, dim) in zip(range(0, 6)[::-1], (-1, -1, -2, -2, -3, -3))
]
)

# crop using (negative) padding;
if not torch.all(torch.tensor(npad_tuple) != 0):
y = F.pad(y, npad_tuple)
y = self._fast_fourier_op.adjoint(y)

# move dim where FFT was already performed such nuFFT can be performed
if len(self._nufft_dims) != 0:
Expand Down Expand Up @@ -297,7 +285,7 @@ def adjoint(self, y: torch.Tensor) -> torch.Tensor:
y = y.contiguous() if y.stride()[-1] != 1 else y
y = self._adj_nufft_op(y, omega, norm='ortho')

# get back to orginal k-space shape
# get back to orginal image shape
nz, ny, nx = self._recon_shape.z, self._recon_shape.y, self._recon_shape.x
y = rearrange(y, target_pattern + '->' + init_pattern, other=nb, coils=nc, dim2=nz, dim1=ny, dim0=nx)

Expand Down
78 changes: 78 additions & 0 deletions src/mrpro/operators/_ZeroPadOp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Class for Zero Pad Operator."""

# Copyright 2023 Physikalisch-Technische Bundesanstalt
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from mrpro.operators import LinearOperator
from mrpro.utils import zero_pad_or_crop


class ZeroPadOp(LinearOperator):
"""Zero Pad operator class."""

def __init__(
self,
dim: tuple[int, ...],
orig_shape: tuple[int, ...],
padded_shape: tuple[int, ...],
) -> None:
"""Zero Pad Operator class.
The operator carries out zero-padding if the padded_shape is larger than orig_shape and cropping if the
padded_shape is smaller.
Parameters
----------
dim
dim along which padding should be applied
orig_shape
shape of original data along dim, same length as dim
padded_shape
shape of padded data along dim, same length as dim
"""
if len(dim) != len(orig_shape) or len(dim) != len(padded_shape):
raise ValueError('Dim, orig_shape and padded_shape have to be of same length')

super().__init__()
self.dim: tuple[int, ...] = dim
self.orig_shape: tuple[int, ...] = orig_shape
self.padded_shape: tuple[int, ...] = padded_shape

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Pad or crop data.
Parameters
----------
x
data with shape orig_shape
Returns
-------
data with shape padded_shape
"""
return zero_pad_or_crop(x, self.padded_shape, self.dim)

def adjoint(self, x: torch.Tensor) -> torch.Tensor:
"""Crop or pad data.
Parameters
----------
x
data with shape padded_shape
Returns
-------
data with shape orig_shape
"""
return zero_pad_or_crop(x, self.orig_shape, self.dim)
2 changes: 2 additions & 0 deletions src/mrpro/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from mrpro.operators._LinearOperator import LinearOperator
from mrpro.operators._NonLinearOperator import NonLinearOperator
from mrpro.operators._SensitivityOp import SensitivityOp
from mrpro.operators._ZeroPadOp import ZeroPadOp
from mrpro.operators._FastFourierOp import FastFourierOp
from mrpro.operators._FourierOp import FourierOp
from mrpro.operators.models._WASABI import WASABI
from mrpro.operators.models._WASABITI import WASABITI
1 change: 1 addition & 0 deletions src/mrpro/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from mrpro.utils._smap import smap
from mrpro.utils._remove_repeat import remove_repeat
from mrpro.utils._rgetattr import rgetattr
from mrpro.utils._zero_pad_or_crop import zero_pad_or_crop
Loading

0 comments on commit 93667b3

Please sign in to comment.