From 93667b3406341909c0bf3203a8b67f7509f52cf7 Mon Sep 17 00:00:00 2001 From: Christoph Kolbitsch Date: Fri, 26 Jan 2024 17:31:48 +0100 Subject: [PATCH] Fftshift for FourierOp (#143) * FastFourierOp and PadOp added --- src/mrpro/algorithms/_remove_readout_os.py | 8 +- src/mrpro/operators/_FastFourierOp.py | 90 +++++++++++++++++++ src/mrpro/operators/_FourierOp.py | 50 ++++------- src/mrpro/operators/_ZeroPadOp.py | 78 ++++++++++++++++ src/mrpro/operators/__init__.py | 2 + src/mrpro/utils/__init__.py | 1 + src/mrpro/utils/_zero_pad_or_crop.py | 89 ++++++++++++++++++ src/mrpro/utils/fft.py | 49 ---------- tests/algorithms/test_remove_readout_os.py | 5 +- tests/data/test_kdata.py | 5 +- .../test_fast_fourier_op.py} | 48 ++++++---- tests/operators/test_zero_pad_op.py | 59 ++++++++++++ tests/phantoms/test_ellipse_phantom.py | 5 +- tests/utils/test_zero_pad_or_crop.py | 30 +++++++ 14 files changed, 412 insertions(+), 107 deletions(-) create mode 100644 src/mrpro/operators/_FastFourierOp.py create mode 100644 src/mrpro/operators/_ZeroPadOp.py create mode 100644 src/mrpro/utils/_zero_pad_or_crop.py delete mode 100644 src/mrpro/utils/fft.py rename tests/{utils/test_fft.py => operators/test_fast_fourier_op.py} (51%) create mode 100644 tests/operators/test_zero_pad_op.py create mode 100644 tests/utils/test_zero_pad_or_crop.py diff --git a/src/mrpro/algorithms/_remove_readout_os.py b/src/mrpro/algorithms/_remove_readout_os.py index 7d098a75d..82b37a4e8 100644 --- a/src/mrpro/algorithms/_remove_readout_os.py +++ b/src/mrpro/algorithms/_remove_readout_os.py @@ -16,8 +16,7 @@ from mrpro.data import KData from mrpro.data import KTrajectory -from mrpro.utils.fft import image_to_kspace -from mrpro.utils.fft import kspace_to_image +from mrpro.operators import FastFourierOp def remove_readout_os(kdata: KData) -> KData: @@ -56,9 +55,10 @@ def crop_readout(input): return input[..., start_cropped_readout:end_cropped_readout] # Transform to image space, crop to reconstruction matrix size and transform back - dat = kspace_to_image(kdata.data, dim=(-1,)) + FFOp = FastFourierOp(dim=(-1,)) + dat = FFOp.adjoint(kdata.data) dat = crop_readout(dat) - dat = image_to_kspace(dat, dim=(-1,)) + dat = FFOp.forward(dat) # Adapt trajectory ks = [kdata.traj.kz, kdata.traj.ky, kdata.traj.kx] diff --git a/src/mrpro/operators/_FastFourierOp.py b/src/mrpro/operators/_FastFourierOp.py new file mode 100644 index 000000000..bbb2ae39f --- /dev/null +++ b/src/mrpro/operators/_FastFourierOp.py @@ -0,0 +1,90 @@ +"""Class for Fast Fourier Operator.""" + +# Copyright 2023 Physikalisch-Technische Bundesanstalt +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from mrpro.operators import LinearOperator +from mrpro.operators import ZeroPadOp + + +class FastFourierOp(LinearOperator): + """Fast Fourier operator class.""" + + def __init__( + self, + dim: tuple[int, ...] = (-3, -2, -1), + recon_shape: tuple[int, ...] | None = None, + encoding_shape: tuple[int, ...] | None = None, + ) -> None: + """Fast Fourier Operator class. 
+
+        Remark regarding the fftshift/ifftshift:
+        fftshift shifts the zero-frequency point to the center of the data, ifftshift undoes this operation.
+        The input to both forward and adjoint assumes that the zero-frequency is in the center of the data.
+        torch.fft.fftn and torch.fft.ifftn expect the zero-frequency to be the first entry in the tensor.
+        Therefore for forward and adjoint first ifftshift needs to be applied, then fftn or ifftn and then fftshift.
+
+        Parameters
+        ----------
+        dim, optional
+            dim along which FFT and IFFT are applied, by default last three dimensions (-3, -2, -1)
+        encoding_shape, optional
+            shape of encoded data
+        recon_shape, optional
+            shape of reconstructed data
+        """
+        super().__init__()
+        self._dim: tuple[int, ...] = dim
+        self._pad_op: ZeroPadOp
+        if encoding_shape is not None and recon_shape is not None:
+            self._pad_op = ZeroPadOp(dim=dim, orig_shape=recon_shape, padded_shape=encoding_shape)
+        else:
+            # No padding
+            self._pad_op = ZeroPadOp(dim=(), orig_shape=(), padded_shape=())
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """FFT from image space to k-space.
+
+        Parameters
+        ----------
+        x
+            image data on Cartesian grid
+
+        Returns
+        -------
+            FFT of x
+        """
+        return torch.fft.fftshift(
+            torch.fft.fftn(torch.fft.ifftshift(self._pad_op.forward(x), dim=self._dim), dim=self._dim, norm='ortho'),
+            dim=self._dim,
+        )
+
+    def adjoint(self, y: torch.Tensor) -> torch.Tensor:
+        """IFFT from k-space to image space.
+
+        Parameters
+        ----------
+        y
+            k-space data on Cartesian grid
+
+        Returns
+        -------
+            IFFT of y
+        """
+        # IFFT
+        return self._pad_op.adjoint(
+            torch.fft.fftshift(
+                torch.fft.ifftn(torch.fft.ifftshift(y, dim=self._dim), dim=self._dim, norm='ortho'), dim=self._dim
+            )
+        )
diff --git a/src/mrpro/operators/_FourierOp.py b/src/mrpro/operators/_FourierOp.py
index b85b82afb..144bc3368 100644
--- a/src/mrpro/operators/_FourierOp.py
+++ b/src/mrpro/operators/_FourierOp.py
@@ -11,7 +11,6 @@
 # limitations under the License.
 
 import torch
-import torch.nn.functional as F
 from einops import rearrange
 from einops import repeat
 from torchkbnufft import KbNufft
@@ -19,6 +18,7 @@
 
 from mrpro.data import KTrajectory
 from mrpro.data import SpatialDimension
+from mrpro.operators import FastFourierOp
 from mrpro.operators import LinearOperator
 
 
@@ -85,7 +85,8 @@ def __init__(
         nufft_dims = []
         fft_dims = []
         ignore_dims = []
-        fft_s = []
+        fft_recon_shape = []
+        fft_encoding_shape = []
         omega = []
         traj_shape = []
         super().__init__()
@@ -93,30 +94,27 @@
         # create information about image shape, k-data shape etc
         # and identify which directions can be ignored, which ones require a nuFFT
        # and for which ones a simple FFT suffices
-        for n, os, k, i in zip(
+        for rs, es, os, k, i in zip(
             (recon_shape.z, recon_shape.y, recon_shape.x),
+            (encoding_shape.z, encoding_shape.y, encoding_shape.x),
             (oversampling.z, oversampling.y, oversampling.x),
             (traj.kz, traj.ky, traj.kx),
             (-3, -2, -1),
         ):
             nk_list = [traj.kz.shape[i], traj.ky.shape[i], traj.kx.shape[i]]
-            if n <= 1 and nk_list.count(1) == 3:
+            if rs <= 1 and nk_list.count(1) == 3:
                 # dimension with no Fourier transform
                 ignore_dims.append(i)
 
             elif nk_list.count(1) == 2:  # and is_uniform(k): #TODO: maybe is_uniform never needed?
-                # dimension with FFT
-                nk = torch.tensor(nk_list)
-                supp = torch.where(nk != 1, nk, 0)
-                s = supp.max().item()
-
-                # append dimension and output shape for oversampled FFT
                 fft_dims.append(i)
-                fft_s.append(s)
+                fft_recon_shape.append(int(rs))
+                fft_encoding_shape.append(int(es))
 
             else:
                 # dimension with nuFFT
-                grid_size.append(int(os * n))
-                nufft_im_size.append(n)
+                grid_size.append(int(os * rs))
+                nufft_im_size.append(rs)
                 nufft_dims.append(i)
 
         # TODO: can omega be created here already?
@@ -148,11 +146,16 @@ def __init__(
         self._ignore_dims = tuple(ignore_dims)
         self._nufft_dims = tuple(nufft_dims)
         self._fft_dims = tuple(fft_dims)
-        self._fft_s = tuple(fft_s)
+        self._fft_recon_shape = tuple(fft_recon_shape)
+        self._fft_encoding_shape = tuple(fft_encoding_shape)
         self._kshape = torch.broadcast_shapes(*traj_shape)
         self._recon_shape = recon_shape
         self._nufft_im_size = nufft_im_size
 
+        self._fast_fourier_op = FastFourierOp(
+            dim=self._fft_dims, recon_shape=self._fft_recon_shape, encoding_shape=self._fft_encoding_shape
+        )
+
     @staticmethod
     def get_target_pattern(fft_dims: tuple[int, ...], nufft_dims: tuple[int, ...], ignore_dims: tuple[int, ...]) -> str:
         """Pattern to reshape image/k-space data to be able to perform nuFFT.
@@ -202,7 +205,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             raise ValueError('image data shape missmatch')
 
         if len(self._fft_dims) != 0:
-            x = torch.fft.fftn(x, s=self._fft_s, dim=self._fft_dims, norm='ortho')
+            x = self._fast_fourier_op.forward(x)
 
         if len(self._nufft_dims) != 0:
             init_pattern = 'other coils dim2 dim1 dim0'
@@ -253,22 +256,7 @@ def adjoint(self, y: torch.Tensor) -> torch.Tensor:
 
         # apply IFFT
         if len(self._fft_dims) != 0:
-            recon_shape = [self._recon_shape.z, self._recon_shape.y, self._recon_shape.x]
-            y = torch.fft.ifftn(y, s=self._fft_s, dim=self._fft_dims, norm='ortho')
-
-            # construct the paddings based on the FFT-directions
-            # TODO: can this be written more nicely?
-            diff_dim = torch.tensor(recon_shape) - torch.tensor(y.shape[2:])
-            npad_tuple = tuple(
-                [
-                    int(diff_dim[i // 2].item()) if (i % 2 == 0 and dim in self._fft_dims) else 0
-                    for (i, dim) in zip(range(0, 6)[::-1], (-1, -1, -2, -2, -3, -3))
-                ]
-            )
-
-            # crop using (negative) padding;
-            if not torch.all(torch.tensor(npad_tuple) != 0):
-                y = F.pad(y, npad_tuple)
+            y = self._fast_fourier_op.adjoint(y)
 
         # move dim where FFT was already performed such nuFFT can be performed
         if len(self._nufft_dims) != 0:
@@ -297,7 +285,7 @@ def adjoint(self, y: torch.Tensor) -> torch.Tensor:
         y = y.contiguous() if y.stride()[-1] != 1 else y
         y = self._adj_nufft_op(y, omega, norm='ortho')
 
-        # get back to orginal k-space shape
+        # get back to original image shape
         nz, ny, nx = self._recon_shape.z, self._recon_shape.y, self._recon_shape.x
         y = rearrange(y, target_pattern + '->' + init_pattern, other=nb, coils=nc, dim2=nz, dim1=ny, dim0=nx)
 
diff --git a/src/mrpro/operators/_ZeroPadOp.py b/src/mrpro/operators/_ZeroPadOp.py
new file mode 100644
index 000000000..0fc3662a3
--- /dev/null
+++ b/src/mrpro/operators/_ZeroPadOp.py
@@ -0,0 +1,78 @@
+"""Class for Zero Pad Operator."""
+
+# Copyright 2023 Physikalisch-Technische Bundesanstalt
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from mrpro.operators import LinearOperator
+from mrpro.utils import zero_pad_or_crop
+
+
+class ZeroPadOp(LinearOperator):
+    """Zero Pad operator class."""
+
+    def __init__(
+        self,
+        dim: tuple[int, ...],
+        orig_shape: tuple[int, ...],
+        padded_shape: tuple[int, ...],
+    ) -> None:
+        """Zero Pad Operator class.
+
+        The operator carries out zero-padding if the padded_shape is larger than orig_shape and cropping if the
+        padded_shape is smaller.
+
+        Parameters
+        ----------
+        dim
+            dim along which padding should be applied
+        orig_shape
+            shape of original data along dim, same length as dim
+        padded_shape
+            shape of padded data along dim, same length as dim
+        """
+        if len(dim) != len(orig_shape) or len(dim) != len(padded_shape):
+            raise ValueError('Dim, orig_shape and padded_shape have to be of same length')
+
+        super().__init__()
+        self.dim: tuple[int, ...] = dim
+        self.orig_shape: tuple[int, ...] = orig_shape
+        self.padded_shape: tuple[int, ...] = padded_shape
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Pad or crop data.
+
+        Parameters
+        ----------
+        x
+            data with shape orig_shape
+
+        Returns
+        -------
+            data with shape padded_shape
+        """
+        return zero_pad_or_crop(x, self.padded_shape, self.dim)
+
+    def adjoint(self, x: torch.Tensor) -> torch.Tensor:
+        """Crop or pad data.
+
+        Parameters
+        ----------
+        x
+            data with shape padded_shape
+
+        Returns
+        -------
+            data with shape orig_shape
+        """
+        return zero_pad_or_crop(x, self.orig_shape, self.dim)
diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py
index 271d93c09..ec6b4c5e0 100644
--- a/src/mrpro/operators/__init__.py
+++ b/src/mrpro/operators/__init__.py
@@ -2,6 +2,8 @@
 from mrpro.operators._LinearOperator import LinearOperator
 from mrpro.operators._NonLinearOperator import NonLinearOperator
 from mrpro.operators._SensitivityOp import SensitivityOp
+from mrpro.operators._ZeroPadOp import ZeroPadOp
+from mrpro.operators._FastFourierOp import FastFourierOp
 from mrpro.operators._FourierOp import FourierOp
 from mrpro.operators.models._WASABI import WASABI
 from mrpro.operators.models._WASABITI import WASABITI
diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py
index deaab59b9..63df948e2 100644
--- a/src/mrpro/utils/__init__.py
+++ b/src/mrpro/utils/__init__.py
@@ -1,3 +1,4 @@
 from mrpro.utils._smap import smap
 from mrpro.utils._remove_repeat import remove_repeat
 from mrpro.utils._rgetattr import rgetattr
+from mrpro.utils._zero_pad_or_crop import zero_pad_or_crop
diff --git a/src/mrpro/utils/_zero_pad_or_crop.py b/src/mrpro/utils/_zero_pad_or_crop.py
new file mode 100644
index 000000000..d33a96051
--- /dev/null
+++ b/src/mrpro/utils/_zero_pad_or_crop.py
@@ -0,0 +1,89 @@
+"""Zero pad and crop data tensor."""
+
+# Copyright 2023 Physikalisch-Technische Bundesanstalt
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import torch
+import torch.nn.functional as F
+
+
+def normalize_index(ndim: int, index: int):
+    """Normalize possibly negative indices.
+
+    Parameters
+    ----------
+    ndim
+        number of dimensions
+    index
+        index to normalize. negative indices count from the end.
+
+    Raises
+    ------
+    IndexError
+        if index is outside [-ndim,ndim)
+    """
+    if 0 <= index < ndim:
+        return index
+    elif -ndim <= index < 0:
+        return ndim + index
+    else:
+        raise IndexError(f'Invalid index {index} for {ndim} data dimensions')
+
+
+def zero_pad_or_crop(
+    data: torch.Tensor, new_shape: tuple[int, ...] | torch.Size, dim: None | tuple[int, ...] = None
+) -> torch.Tensor:
+    """Change shape of data by cropping or zero-padding.
+
+    Parameters
+    ----------
+    data
+        data
+    new_shape
+        desired shape of data
+    dim
+        dimensions the new_shape corresponds to. None (default) is interpreted as last len(new_shape) dimensions.
+
+    Returns
+    -------
+        data zero padded or cropped to shape
+    """
+
+    if len(new_shape) > data.ndim:
+        raise ValueError('length of new shape should not exceed dimensions of data')
+
+    if dim is None:  # Use last dimensions
+        new_shape_full = torch.tensor(data.shape[: -len(new_shape)] + tuple(new_shape))
+    else:
+        if len(new_shape) != len(dim):
+            raise ValueError('length of shape should match length of dim')
+        dim = tuple(normalize_index(data.ndim, idx) for idx in dim)  # raises if any not in [-data.ndim,data.ndim)
+        if len(dim) != len(set(dim)):  # this is why we normalize
+            raise ValueError('repeated values are not allowed in dims')
+        new_shape_full = torch.tensor(data.shape)
+        for i, d in enumerate(dim):
+            new_shape_full[d] = new_shape[i]
+
+    npad = []
+    for old, new in zip(torch.tensor(data.shape), new_shape_full):
+        diff = (new - old).numpy()  # __trunc__ method not available for tensors
+        after = math.trunc(diff / 2)
+        before = diff - after
+        npad.append(before)
+        npad.append(after)
+
+    if any(npad):
+        # F.pad expects paddings in reversed order
+        data = F.pad(data, npad[::-1])
+    return data
diff --git a/src/mrpro/utils/fft.py b/src/mrpro/utils/fft.py
deleted file mode 100644
index 50fd216e1..000000000
--- a/src/mrpro/utils/fft.py
+++ /dev/null
@@ -1,49 +0,0 @@
-"""Wrapper for FFT and IFFT."""
-
-# Copyright 2023 Physikalisch-Technische Bundesanstalt
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-
-
-def kspace_to_image(kdat: torch.Tensor, dim: tuple[int, ...] = (-1, -2, -3)) -> torch.Tensor:
-    """IFFT from k-space to image space.
- - Parameters - ---------- - kdat - k-space data on Cartesian grid - dim, optional - dim along which iFFT is applied, by default last three dimensions (-1, -2, -3) - - Returns - ------- - iFFT of kdat - """ - return torch.fft.fftshift(torch.fft.ifftn(torch.fft.ifftshift(kdat, dim=dim), dim=dim, norm='ortho'), dim=dim) - - -def image_to_kspace(idat: torch.Tensor, dim: tuple[int, ...] = (-1, -2, -3)) -> torch.Tensor: - """FFT from image space to k-space. - - Parameters - ---------- - idat - image data on Cartesian grid - dim, optional - dim along which FFT is applied, by default last three dimensions (-1, -2, -3) - - Returns - ------- - FFT of idat - """ - return torch.fft.ifftshift(torch.fft.fftn(torch.fft.fftshift(idat, dim=dim), dim=dim, norm='ortho'), dim=dim) diff --git a/tests/algorithms/test_remove_readout_os.py b/tests/algorithms/test_remove_readout_os.py index e06c7ae58..f7305b5df 100644 --- a/tests/algorithms/test_remove_readout_os.py +++ b/tests/algorithms/test_remove_readout_os.py @@ -19,7 +19,7 @@ from mrpro.data import KData from mrpro.data import KTrajectory from mrpro.data import SpatialDimension -from mrpro.utils.fft import kspace_to_image +from mrpro.operators import FastFourierOp from tests import RandomGenerator from tests.conftest import random_kheader from tests.helper import rel_image_diff @@ -74,7 +74,8 @@ def test_remove_readout_os(monkeypatch, random_kheader): kdata = remove_readout_os(kdata) # Reconstruct image from k-space data of one coil and compare to phantom image - idat_rec = kspace_to_image(kdata.data[:, 0, ...], dim=(-1, -2)) + FFOp = FastFourierOp(dim=(-1, -2)) + idat_rec = FFOp.adjoint(kdata.data[:, 0, ...]) # Due to discretisation artifacts the reconstructed image will be different to the reference image. Using standard # testing functions such as numpy.testing.assert_almost_equal fails because there are few voxels with high diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index fe2643ee4..0974be919 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -18,7 +18,7 @@ from mrpro.data import KData from mrpro.data import KTrajectory from mrpro.data.traj_calculators._KTrajectoryCalculator import DummyTrajectory -from mrpro.utils.fft import kspace_to_image +from mrpro.operators import FastFourierOp from tests.data import IsmrmrdRawTestData from tests.helper import rel_image_diff from tests.phantoms.test_ellipse_phantom import ph_ellipse @@ -66,7 +66,8 @@ def test_KData_from_file_diff_nky_for_rep(ismrmrd_cart_invalid_reps): def test_KData_kspace(ismrmrd_cart): """Read in data and verify k-space by comparing reconstructed image.""" k = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) - irec = kspace_to_image(k.data, dim=(-1, -2)) + FFOp = FastFourierOp(dim=(-1, -2)) + irec = FFOp.adjoint(k.data) # Due to discretisation artifacts the reconstructed image will be different to the reference image. 
Using standard # testing functions such as numpy.testing.assert_almost_equal fails because there are few voxels with high diff --git a/tests/utils/test_fft.py b/tests/operators/test_fast_fourier_op.py similarity index 51% rename from tests/utils/test_fft.py rename to tests/operators/test_fast_fourier_op.py index 36c8dc210..72c291ed5 100644 --- a/tests/utils/test_fft.py +++ b/tests/operators/test_fast_fourier_op.py @@ -1,4 +1,4 @@ -"""Tests for image space - k-space transformations.""" +"""Tests for Fast Fourier Operator class.""" # Copyright 2023 Physikalisch-Technische Bundesanstalt # @@ -16,13 +16,13 @@ import pytest import torch -from mrpro.utils.fft import image_to_kspace -from mrpro.utils.fft import kspace_to_image +from mrpro.operators import FastFourierOp +from tests import RandomGenerator @pytest.mark.parametrize('npoints, a', [(100, 20), (300, 20)]) -def test_kspace_to_image(npoints, a): - """Test k-space to image transformation using a Gaussian.""" +def test_fast_fourier_op_forward(npoints, a): + """Test Fast Fourier Op transformation using a Gaussian.""" # Utilize that a Fourier transform of a Gaussian function is given by # F(exp(-x^2/a)) = sqrt(pi*a)exp(-a*pi^2k^2) @@ -36,21 +36,35 @@ def test_kspace_to_image(npoints, a): igauss = torch.exp(-(x**2) / a).to(torch.complex64) kgauss = np.sqrt(torch.pi * a) * torch.exp(-a * torch.pi**2 * k**2).to(torch.complex64) - # Transform k-space to image - kgauss_fft = kspace_to_image(kgauss, dim=(0,)) + # Transform image to k-space + FFOp = FastFourierOp(dim=(0,)) + igauss_fwd = FFOp.forward(igauss) # Scaling to "undo" fft scaling - kgauss_fft *= 2 / np.sqrt(npoints) - torch.testing.assert_close(kgauss_fft, igauss) + igauss_fwd *= np.sqrt(npoints) / 2 + torch.testing.assert_close(igauss_fwd, kgauss) -def test_image_to_kspace_as_inverse(): - """Test if image_to_kspace is the inverse of kspace_to_image.""" +@pytest.mark.parametrize( + 'encoding_shape, recon_shape', + [ + ((101, 201, 50), (13, 221, 64)), + ((100, 200, 50), (14, 220, 64)), + ((101, 201, 50), (14, 220, 64)), + ((100, 200, 50), (13, 221, 64)), + ], +) +def test_fast_fourier_op_adjoint(encoding_shape, recon_shape): + """Test adjointness of Fast Fourier Op.""" - # Create random 3D data set - npoints = [200, 100, 50] - idat = torch.randn(*npoints, dtype=torch.complex64) + # Create test data + generator = RandomGenerator(seed=0) + x = generator.complex64_tensor(recon_shape) + y = generator.complex64_tensor(encoding_shape) - # Transform to k-space and back along all three dimensions - idat_transform = image_to_kspace(kspace_to_image(idat)) - torch.testing.assert_close(idat, idat_transform) + # Create operator and apply + FFOp = FastFourierOp(recon_shape=recon_shape, encoding_shape=encoding_shape) + Ax = FFOp.forward(x) + AHy = FFOp.adjoint(y) + + assert torch.isclose(torch.vdot(Ax.flatten(), y.flatten()), torch.vdot(x.flatten(), AHy.flatten()), rtol=1e-3) diff --git a/tests/operators/test_zero_pad_op.py b/tests/operators/test_zero_pad_op.py new file mode 100644 index 000000000..81d0c17e3 --- /dev/null +++ b/tests/operators/test_zero_pad_op.py @@ -0,0 +1,59 @@ +"""Tests for Zero Pad Operator class.""" + +# Copyright 2023 Physikalisch-Technische Bundesanstalt +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import torch + +from mrpro.operators import ZeroPadOp +from tests import RandomGenerator + + +def test_zero_pad_op_content(): + """Test correct padding.""" + dshape_orig = (2, 100, 3, 200, 50, 2) + dshape_new = (2, 80, 3, 100, 240, 2) + generator = RandomGenerator(seed=0) + dorig = generator.complex64_tensor(dshape_orig) + pad_dim = (-5, -3, -2) + POp = ZeroPadOp( + dim=pad_dim, + orig_shape=tuple([dshape_orig[d] for d in pad_dim]), + padded_shape=tuple([dshape_new[d] for d in pad_dim]), + ) + dnew = POp.forward(dorig) + + # Compare overlapping region + torch.testing.assert_close(dorig[:, 10:90, :, 50:150, :, :], dnew[:, :, :, :, 95:145, :]) + + +@pytest.mark.parametrize( + 'u_shape, v_shape', + [ + ((101, 201, 50), (13, 221, 64)), + ((100, 200, 50), (14, 220, 64)), + ((101, 201, 50), (14, 220, 64)), + ((100, 200, 50), (13, 221, 64)), + ], +) +def test_zero_pad_op_ajoint(u_shape, v_shape): + """Test adjointness of pad operator.""" + generator = RandomGenerator(seed=0) + u = generator.complex64_tensor(u_shape) + v = generator.complex64_tensor(v_shape) + POp = ZeroPadOp(dim=(-3, -2, -1), orig_shape=u_shape, padded_shape=v_shape) + Au = POp.forward(u) + AHv = POp.adjoint(v) + + assert torch.isclose(torch.vdot(Au.flatten(), v.flatten()), torch.vdot(u.flatten(), AHv.flatten()), rtol=1e-3) diff --git a/tests/phantoms/test_ellipse_phantom.py b/tests/phantoms/test_ellipse_phantom.py index c7d710a66..f1a839c22 100644 --- a/tests/phantoms/test_ellipse_phantom.py +++ b/tests/phantoms/test_ellipse_phantom.py @@ -16,7 +16,7 @@ import torch from mrpro.data import SpatialDimension -from mrpro.utils.fft import kspace_to_image +from mrpro.operators import FastFourierOp from tests.helper import rel_image_diff from tests.phantoms._EllipsePhantomTestData import EllipsePhantomTestData @@ -55,7 +55,8 @@ def test_kspace_image_match(ph_ellipse): im_dim = SpatialDimension(z=1, y=ph_ellipse.ny, x=ph_ellipse.nx) im = ph_ellipse.phantom.image_space(im_dim) kdat = ph_ellipse.phantom.kspace(ph_ellipse.ky, ph_ellipse.kx) - irec = kspace_to_image(kdat, dim=(-1, -2)) + FFOp = FastFourierOp(dim=(-1, -2)) + irec = FFOp.adjoint(kdat) # Due to discretisation artifacts the reconstructed image will be different to the reference image. Using standard # testing functions such as numpy.testing.assert_almost_equal fails because there are few voxels with high # differences along the edges of the elliptic objects. diff --git a/tests/utils/test_zero_pad_or_crop.py b/tests/utils/test_zero_pad_or_crop.py new file mode 100644 index 000000000..790a52983 --- /dev/null +++ b/tests/utils/test_zero_pad_or_crop.py @@ -0,0 +1,30 @@ +"""Tests for zero padding and cropping of data tensors.""" + +# Copyright 2023 Physikalisch-Technische Bundesanstalt +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from mrpro.utils._zero_pad_or_crop import zero_pad_or_crop +from tests import RandomGenerator + + +def test_zero_pad_or_crop_content(): + """Test changing data by cropping and padding.""" + generator = RandomGenerator(seed=0) + dshape_orig = (100, 200, 50) + dshape_new = (80, 100, 240) + dorig = generator.complex64_tensor(dshape_orig) + dnew = zero_pad_or_crop(dorig, dshape_new, dim=(-3, -2, -1)) + + # Compare overlapping region + torch.testing.assert_close(dorig[10:90, 50:150, :], dnew[:, :, 95:145])
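
Usage note (appended, not part of the patch itself): the sketch below shows how the two operators introduced here compose, mirroring the adjointness tests added above. It assumes a mrpro checkout that already contains this change plus PyTorch; the shapes, variable names and tolerance are illustrative only.

import torch

from mrpro.operators import FastFourierOp
from mrpro.operators import ZeroPadOp

# FFT over the last two dimensions; the reconstructed 96x96 image is zero-padded
# to a 128x128 encoding matrix before the forward FFT (and cropped again in adjoint).
ff_op = FastFourierOp(dim=(-2, -1), recon_shape=(96, 96), encoding_shape=(128, 128))

img = torch.randn(4, 96, 96, dtype=torch.complex64)  # e.g. (coils, y, x)
kspace = ff_op.forward(img)       # image -> k-space, zero-frequency in the center
img_back = ff_op.adjoint(kspace)  # k-space -> image, cropped back to recon_shape
print(kspace.shape, img_back.shape)  # torch.Size([4, 128, 128]) torch.Size([4, 96, 96])

# Adjointness check <A x, y> == <x, A^H y>, as in test_fast_fourier_op_adjoint.
x = torch.randn(4, 96, 96, dtype=torch.complex64)
y = torch.randn(4, 128, 128, dtype=torch.complex64)
lhs = torch.vdot(ff_op.forward(x).flatten(), y.flatten())
rhs = torch.vdot(x.flatten(), ff_op.adjoint(y).flatten())
assert torch.isclose(lhs, rhs, rtol=1e-3)

# ZeroPadOp on its own: forward pads or crops to padded_shape, adjoint restores orig_shape.
pad_op = ZeroPadOp(dim=(-2, -1), orig_shape=(96, 96), padded_shape=(128, 128))
print(pad_op.forward(img).shape)  # torch.Size([4, 128, 128])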