Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fftshift for FourierOp #143

Merged
merged 5 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/mrpro/algorithms/_remove_readout_os.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

from mrpro.data import KData
from mrpro.data import KTrajectory
from mrpro.utils.fft import image_to_kspace
from mrpro.utils.fft import kspace_to_image
from mrpro.operators import FastFourierOp


def remove_readout_os(kdata: KData) -> KData:
Expand Down Expand Up @@ -56,9 +55,10 @@ def crop_readout(input):
return input[..., start_cropped_readout:end_cropped_readout]

# Transform to image space, crop to reconstruction matrix size and transform back
dat = kspace_to_image(kdata.data, dim=(-1,))
FFOp = FastFourierOp(dim=(-1,))
dat = FFOp.adjoint(kdata.data)
dat = crop_readout(dat)
dat = image_to_kspace(dat, dim=(-1,))
dat = FFOp.forward(dat)

# Adapt trajectory
ks = [kdata.traj.kz, kdata.traj.ky, kdata.traj.kx]
Expand Down
90 changes: 90 additions & 0 deletions src/mrpro/operators/_FastFourierOp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Class for Fast Fourier Operator."""

# Copyright 2023 Physikalisch-Technische Bundesanstalt
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from mrpro.operators import LinearOperator
from mrpro.operators import PadOp


class FastFourierOp(LinearOperator):
"""Fast Fourier operator class."""

def __init__(
self,
dim: tuple[int, ...] = (-3, -2, -1),
recon_shape: tuple[int, ...] | None = None,
encoding_shape: tuple[int, ...] | None = None,
) -> None:
"""Fast Fourier Operator class.

Remark regarding the fftshift/ifftshift:
fftshift shifts the zero-frequency point to the center of the data, ifftshift undoes this operation.
The input to both forward and ajoint assumes that the zero-frequency is in the center of the data.
Torch.fft.fftn and torch.fft.ifftn expect the zero-frequency to be the first entry in the tensor.
Therefore for forward and ajoint first ifftshift needs to be applied, then fftn or ifftn and then ifftshift.

Parameters
----------
dim, optional
dim along which FFT and IFFT are applied, by default last three dimensions (-1, -2, -3)
encoding_shape, optional
shape of encoded data
recon_shape, optional
shape of reconstructed data
"""
super().__init__()
self._dim: tuple[int, ...] = dim
self._pad_op: PadOp
if encoding_shape is not None and recon_shape is not None:
self._pad_op = PadOp(dim=dim, orig_shape=recon_shape, padded_shape=encoding_shape)
else:
# No padding
self._pad_op = PadOp(dim=(), orig_shape=(), padded_shape=())

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""FFT from image space to k-space.

Parameters
----------
x
image data on Cartesian grid

Returns
-------
FFT of x
"""
return torch.fft.fftshift(
torch.fft.fftn(torch.fft.ifftshift(self._pad_op.forward(x), dim=self._dim), dim=self._dim, norm='ortho'),
dim=self._dim,
)

def adjoint(self, y: torch.Tensor) -> torch.Tensor:
"""IFFT from k-space to image space.

Parameters
----------
y
k-space data on Cartesian grid

Returns
-------
IFFT of y
"""
# FFT
return self._pad_op.adjoint(
torch.fft.fftshift(
torch.fft.ifftn(torch.fft.ifftshift(y, dim=self._dim), dim=self._dim, norm='ortho'), dim=self._dim
)
)
50 changes: 19 additions & 31 deletions src/mrpro/operators/_FourierOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
# limitations under the License.

import torch
import torch.nn.functional as F
from einops import rearrange
from einops import repeat
from torchkbnufft import KbNufft
from torchkbnufft import KbNufftAdjoint

from mrpro.data import KTrajectory
from mrpro.data import SpatialDimension
from mrpro.operators import FastFourierOp
from mrpro.operators import LinearOperator


Expand Down Expand Up @@ -85,38 +85,36 @@ def __init__(
nufft_dims = []
fft_dims = []
ignore_dims = []
fft_s = []
fft_recon_shape = []
ckolbPTB marked this conversation as resolved.
Show resolved Hide resolved
fft_encoding_shape = []
omega = []
traj_shape = []
super().__init__()

# create information about image shape, k-data shape etc
# and identify which directions can be ignored, which ones require a nuFFT
# and for which ones a simple FFT suffices
for n, os, k, i in zip(
for rs, es, os, k, i in zip(
(recon_shape.z, recon_shape.y, recon_shape.x),
(encoding_shape.z, encoding_shape.y, encoding_shape.x),
(oversampling.z, oversampling.y, oversampling.x),
(traj.kz, traj.ky, traj.kx),
(-3, -2, -1),
):
nk_list = [traj.kz.shape[i], traj.ky.shape[i], traj.kx.shape[i]]
if n <= 1 and nk_list.count(1) == 3:
if rs <= 1 and nk_list.count(1) == 3:
# dimension with no Fourier transform
ignore_dims.append(i)

elif nk_list.count(1) == 2: # and is_uniform(k): #TODO: maybe is_uniform never needed?
# dimension with FFT
nk = torch.tensor(nk_list)
supp = torch.where(nk != 1, nk, 0)
s = supp.max().item()

# append dimension and output shape for oversampled FFT
fft_dims.append(i)
fft_s.append(s)
fft_recon_shape.append(int(rs))
fft_encoding_shape.append(int(es))
else:
# dimension with nuFFT
grid_size.append(int(os * n))
nufft_im_size.append(n)
grid_size.append(int(os * rs))
nufft_im_size.append(rs)
nufft_dims.append(i)

# TODO: can omega be created here already?
Expand Down Expand Up @@ -148,11 +146,16 @@ def __init__(
self._ignore_dims = tuple(ignore_dims)
self._nufft_dims = tuple(nufft_dims)
self._fft_dims = tuple(fft_dims)
self._fft_s = tuple(fft_s)
self._fft_recon_shape = tuple(fft_recon_shape)
self._fft_encoding_shape = tuple(fft_encoding_shape)
self._kshape = torch.broadcast_shapes(*traj_shape)
self._recon_shape = recon_shape
self._nufft_im_size = nufft_im_size

self._fast_fourier_op = FastFourierOp(
dim=self._fft_dims, recon_shape=self._fft_recon_shape, encoding_shape=self._fft_encoding_shape
)

@staticmethod
def get_target_pattern(fft_dims: tuple[int, ...], nufft_dims: tuple[int, ...], ignore_dims: tuple[int, ...]) -> str:
"""Pattern to reshape image/k-space data to be able to perform nuFFT.
Expand Down Expand Up @@ -202,7 +205,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
raise ValueError('image data shape missmatch')

if len(self._fft_dims) != 0:
x = torch.fft.fftn(x, s=self._fft_s, dim=self._fft_dims, norm='ortho')
x = self._fast_fourier_op.forward(x)

if len(self._nufft_dims) != 0:
init_pattern = 'other coils dim2 dim1 dim0'
Expand Down Expand Up @@ -253,22 +256,7 @@ def adjoint(self, y: torch.Tensor) -> torch.Tensor:

# apply IFFT
if len(self._fft_dims) != 0:
recon_shape = [self._recon_shape.z, self._recon_shape.y, self._recon_shape.x]
y = torch.fft.ifftn(y, s=self._fft_s, dim=self._fft_dims, norm='ortho')

# construct the paddings based on the FFT-directions
# TODO: can this be written more nicely?
diff_dim = torch.tensor(recon_shape) - torch.tensor(y.shape[2:])
npad_tuple = tuple(
[
int(diff_dim[i // 2].item()) if (i % 2 == 0 and dim in self._fft_dims) else 0
for (i, dim) in zip(range(0, 6)[::-1], (-1, -1, -2, -2, -3, -3))
]
)

# crop using (negative) padding;
if not torch.all(torch.tensor(npad_tuple) != 0):
y = F.pad(y, npad_tuple)
y = self._fast_fourier_op.adjoint(y)

# move dim where FFT was already performed such nuFFT can be performed
if len(self._nufft_dims) != 0:
Expand Down Expand Up @@ -297,7 +285,7 @@ def adjoint(self, y: torch.Tensor) -> torch.Tensor:
y = y.contiguous() if y.stride()[-1] != 1 else y
y = self._adj_nufft_op(y, omega, norm='ortho')

# get back to orginal k-space shape
# get back to orginal image shape
nz, ny, nx = self._recon_shape.z, self._recon_shape.y, self._recon_shape.x
y = rearrange(y, target_pattern + '->' + init_pattern, other=nb, coils=nc, dim2=nz, dim1=ny, dim0=nx)

Expand Down
103 changes: 103 additions & 0 deletions src/mrpro/operators/_PadOp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Class for Pad Operator."""

# Copyright 2023 Physikalisch-Technische Bundesanstalt
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from mrpro.operators import LinearOperator
from mrpro.utils import change_data_shape


class PadOp(LinearOperator):
"""Pad operator class."""

def __init__(
self,
dim: tuple[int, ...],
ckolbPTB marked this conversation as resolved.
Show resolved Hide resolved
orig_shape: tuple[int, ...],
padded_shape: tuple[int, ...],
) -> None:
"""Pad Operator class.

The operator carries out zero-padding if the padded_shape is larger than orig_shape and cropping if the
padded_shape is smaller.

Parameters
----------
dim
dim along which padding should be applied
orig_shape
shape of original data along dim, same length as dim
padded_shape
shape of padded data along dim, same length as dim
"""
if len(dim) != len(orig_shape) or len(dim) != len(padded_shape):
raise ValueError('Dim, orig_shape and padded_shape have to be of same length')

super().__init__()
self.dim: tuple[int, ...] = dim
self.orig_shape: tuple[int, ...] = orig_shape
self.padded_shape: tuple[int, ...] = padded_shape

@staticmethod
def _pad_data(x: torch.Tensor, dim: tuple[int, ...], padded_shape: tuple[int, ...]) -> torch.Tensor:
ckolbPTB marked this conversation as resolved.
Show resolved Hide resolved
"""Pad or crop data.

Parameters
----------
x
original data
dim
dim along which padding should be applied
padded_shape
shape of padded data

Returns
-------
data with shape padded_shape
"""
# Adapt image size
if len(dim) > 0:
s = list(x.shape)
for idx, idim in enumerate(dim):
s[idim] = padded_shape[idx]
x = change_data_shape(x, tuple(s))
return x

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Pad or crop data.

Parameters
----------
x
data with shape orig_shape

Returns
-------
data with shape padded_shape
"""
return self._pad_data(x, self.dim, self.padded_shape)

def adjoint(self, x: torch.Tensor) -> torch.Tensor:
"""Crop or pad data.

Parameters
----------
x
data with shape padded_shape

Returns
-------
data with shape orig_shape
"""
return self._pad_data(x, self.dim, self.orig_shape)
2 changes: 2 additions & 0 deletions src/mrpro/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from mrpro.operators._LinearOperator import LinearOperator
from mrpro.operators._NonLinearOperator import NonLinearOperator
from mrpro.operators._SensitivityOp import SensitivityOp
from mrpro.operators._PadOp import PadOp
from mrpro.operators._FastFourierOp import FastFourierOp
from mrpro.operators._FourierOp import FourierOp
from mrpro.operators.models._WASABI import WASABI
from mrpro.operators.models._WASABITI import WASABITI
1 change: 1 addition & 0 deletions src/mrpro/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from mrpro.utils._smap import smap
from mrpro.utils._remove_repeat import remove_repeat
from mrpro.utils._rgetattr import rgetattr
from mrpro.utils._change_data_shape import change_data_shape
Loading
Loading