From acc6fe66d25b2a7c3afc8b43cdfcab2382ff75a0 Mon Sep 17 00:00:00 2001 From: ckolbPTB Date: Fri, 22 Dec 2023 15:56:30 +0100 Subject: [PATCH 1/5] FastFourierOp and PadOp added --- src/mrpro/algorithms/_remove_readout_os.py | 8 +- src/mrpro/operators/_FastFourierOp.py | 84 ++++++++++++++++ src/mrpro/operators/_FourierOp.py | 50 ++++------ src/mrpro/operators/_PadOp.py | 99 +++++++++++++++++++ src/mrpro/operators/__init__.py | 2 + src/mrpro/utils/__init__.py | 1 + src/mrpro/utils/_change_data_shape.py | 49 +++++++++ src/mrpro/utils/fft.py | 49 --------- tests/algorithms/test_remove_readout_os.py | 5 +- tests/data/test_kdata.py | 5 +- .../test_fast_fourier_op.py} | 46 +++++---- tests/operators/test_pad_op.py | 56 +++++++++++ tests/phantoms/test_ellipse_phantom.py | 5 +- tests/utils/test_change_data_shape.py | 30 ++++++ 14 files changed, 382 insertions(+), 107 deletions(-) create mode 100644 src/mrpro/operators/_FastFourierOp.py create mode 100644 src/mrpro/operators/_PadOp.py create mode 100644 src/mrpro/utils/_change_data_shape.py delete mode 100644 src/mrpro/utils/fft.py rename tests/{utils/test_fft.py => operators/test_fast_fourier_op.py} (53%) create mode 100644 tests/operators/test_pad_op.py create mode 100644 tests/utils/test_change_data_shape.py diff --git a/src/mrpro/algorithms/_remove_readout_os.py b/src/mrpro/algorithms/_remove_readout_os.py index 7d098a75d..82b37a4e8 100644 --- a/src/mrpro/algorithms/_remove_readout_os.py +++ b/src/mrpro/algorithms/_remove_readout_os.py @@ -16,8 +16,7 @@ from mrpro.data import KData from mrpro.data import KTrajectory -from mrpro.utils.fft import image_to_kspace -from mrpro.utils.fft import kspace_to_image +from mrpro.operators import FastFourierOp def remove_readout_os(kdata: KData) -> KData: @@ -56,9 +55,10 @@ def crop_readout(input): return input[..., start_cropped_readout:end_cropped_readout] # Transform to image space, crop to reconstruction matrix size and transform back - dat = kspace_to_image(kdata.data, dim=(-1,)) + FFOp = FastFourierOp(dim=(-1,)) + dat = FFOp.adjoint(kdata.data) dat = crop_readout(dat) - dat = image_to_kspace(dat, dim=(-1,)) + dat = FFOp.forward(dat) # Adapt trajectory ks = [kdata.traj.kz, kdata.traj.ky, kdata.traj.kx] diff --git a/src/mrpro/operators/_FastFourierOp.py b/src/mrpro/operators/_FastFourierOp.py new file mode 100644 index 000000000..092e7eb82 --- /dev/null +++ b/src/mrpro/operators/_FastFourierOp.py @@ -0,0 +1,84 @@ +"""Class for Fast Fourier Operator.""" + +# Copyright 2023 Physikalisch-Technische Bundesanstalt +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from mrpro.operators import LinearOperator +from mrpro.operators import PadOp + + +class FastFourierOp(LinearOperator): + """Fast Fourier operator class.""" + + def __init__( + self, + dim: tuple[int, ...] = (-3, -2, -1), + encoding_shape: tuple[int, ...] | None = None, + recon_shape: tuple[int, ...] | None = None, + ) -> None: + """Fast Fourier Operator class. + + Remark regarding the fftshift/ifftshift: + fftshift shifts the zero-frequency point to the center of the data, ifftshift undoes this operation. + The input to both forward and ajoint assumes that the zero-frequency is in the center of the data. + Torch.fft.fftn and torch.fft.ifftn expect the zero-frequency to be the first entry in the tensor. + Therefore for forward and ajoint first ifftshift needs to be applied, then fftn or ifftn and then ifftshift. + + Parameters + ---------- + dim, optional + dim along which FFT and IFFT are applied, by default last three dimensions (-1, -2, -3) + encoding_shape, optional + shape of encoded data + recon_shape, optional + shape of reconstructed data + """ + self._dim: tuple[int, ...] = dim + self._pad_op: PadOp = PadOp(dim=dim, orig_shape=encoding_shape, padded_shape=recon_shape) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """FFT from image space to k-space. + + Parameters + ---------- + x + image data on Cartesian grid + + Returns + ------- + FFT of x + """ + return torch.fft.fftshift( + torch.fft.fftn(torch.fft.ifftshift(self._pad_op.forward(x), dim=self._dim), dim=self._dim, norm='ortho'), + dim=self._dim, + ) + + def adjoint(self, y: torch.Tensor) -> torch.Tensor: + """IFFT from k-space to image space. + + Parameters + ---------- + y + k-space data on Cartesian grid + + Returns + ------- + IFFT of y + """ + # FFT + return self._pad_op.adjoint( + torch.fft.fftshift( + torch.fft.ifftn(torch.fft.ifftshift(y, dim=self._dim), dim=self._dim, norm='ortho'), dim=self._dim + ) + ) diff --git a/src/mrpro/operators/_FourierOp.py b/src/mrpro/operators/_FourierOp.py index b85b82afb..530874713 100644 --- a/src/mrpro/operators/_FourierOp.py +++ b/src/mrpro/operators/_FourierOp.py @@ -11,7 +11,6 @@ # limitations under the License. import torch -import torch.nn.functional as F from einops import rearrange from einops import repeat from torchkbnufft import KbNufft @@ -19,6 +18,7 @@ from mrpro.data import KTrajectory from mrpro.data import SpatialDimension +from mrpro.operators import FastFourierOp from mrpro.operators import LinearOperator @@ -85,7 +85,8 @@ def __init__( nufft_dims = [] fft_dims = [] ignore_dims = [] - fft_s = [] + fft_recon_shape = [] + fft_encoding_shape = [] omega = [] traj_shape = [] super().__init__() @@ -93,30 +94,27 @@ def __init__( # create information about image shape, k-data shape etc # and identify which directions can be ignored, which ones require a nuFFT # and for which ones a simple FFT suffices - for n, os, k, i in zip( + for rs, es, os, k, i in zip( (recon_shape.z, recon_shape.y, recon_shape.x), + (encoding_shape.z, encoding_shape.y, encoding_shape.x), (oversampling.z, oversampling.y, oversampling.x), (traj.kz, traj.ky, traj.kx), (-3, -2, -1), ): nk_list = [traj.kz.shape[i], traj.ky.shape[i], traj.kx.shape[i]] - if n <= 1 and nk_list.count(1) == 3: + if rs <= 1 and nk_list.count(1) == 3: # dimension with no Fourier transform ignore_dims.append(i) elif nk_list.count(1) == 2: # and is_uniform(k): #TODO: maybe is_uniform never needed? - # dimension with FFT - nk = torch.tensor(nk_list) - supp = torch.where(nk != 1, nk, 0) - s = supp.max().item() - # append dimension and output shape for oversampled FFT fft_dims.append(i) - fft_s.append(s) + fft_recon_shape.append(int(rs)) + fft_encoding_shape.append(int(es)) else: # dimension with nuFFT - grid_size.append(int(os * n)) - nufft_im_size.append(n) + grid_size.append(int(os * rs)) + nufft_im_size.append(rs) nufft_dims.append(i) # TODO: can omega be created here already? @@ -148,11 +146,16 @@ def __init__( self._ignore_dims = tuple(ignore_dims) self._nufft_dims = tuple(nufft_dims) self._fft_dims = tuple(fft_dims) - self._fft_s = tuple(fft_s) + self._fft_recon_shape = tuple(fft_recon_shape) + self._fft_encoding_shape = tuple(fft_encoding_shape) self._kshape = torch.broadcast_shapes(*traj_shape) self._recon_shape = recon_shape self._nufft_im_size = nufft_im_size + self._fast_fourier_op = FastFourierOp( + dim=self._fft_dims, encoding_shape=self._fft_encoding_shape, recon_shape=self._fft_recon_shape + ) + @staticmethod def get_target_pattern(fft_dims: tuple[int, ...], nufft_dims: tuple[int, ...], ignore_dims: tuple[int, ...]) -> str: """Pattern to reshape image/k-space data to be able to perform nuFFT. @@ -202,7 +205,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: raise ValueError('image data shape missmatch') if len(self._fft_dims) != 0: - x = torch.fft.fftn(x, s=self._fft_s, dim=self._fft_dims, norm='ortho') + x = self._fast_fourier_op.forward(x) if len(self._nufft_dims) != 0: init_pattern = 'other coils dim2 dim1 dim0' @@ -253,22 +256,7 @@ def adjoint(self, y: torch.Tensor) -> torch.Tensor: # apply IFFT if len(self._fft_dims) != 0: - recon_shape = [self._recon_shape.z, self._recon_shape.y, self._recon_shape.x] - y = torch.fft.ifftn(y, s=self._fft_s, dim=self._fft_dims, norm='ortho') - - # construct the paddings based on the FFT-directions - # TODO: can this be written more nicely? - diff_dim = torch.tensor(recon_shape) - torch.tensor(y.shape[2:]) - npad_tuple = tuple( - [ - int(diff_dim[i // 2].item()) if (i % 2 == 0 and dim in self._fft_dims) else 0 - for (i, dim) in zip(range(0, 6)[::-1], (-1, -1, -2, -2, -3, -3)) - ] - ) - - # crop using (negative) padding; - if not torch.all(torch.tensor(npad_tuple) != 0): - y = F.pad(y, npad_tuple) + y = self._fast_fourier_op.adjoint(y) # move dim where FFT was already performed such nuFFT can be performed if len(self._nufft_dims) != 0: @@ -297,7 +285,7 @@ def adjoint(self, y: torch.Tensor) -> torch.Tensor: y = y.contiguous() if y.stride()[-1] != 1 else y y = self._adj_nufft_op(y, omega, norm='ortho') - # get back to orginal k-space shape + # get back to orginal image shape nz, ny, nx = self._recon_shape.z, self._recon_shape.y, self._recon_shape.x y = rearrange(y, target_pattern + '->' + init_pattern, other=nb, coils=nc, dim2=nz, dim1=ny, dim0=nx) diff --git a/src/mrpro/operators/_PadOp.py b/src/mrpro/operators/_PadOp.py new file mode 100644 index 000000000..e8b4db1f2 --- /dev/null +++ b/src/mrpro/operators/_PadOp.py @@ -0,0 +1,99 @@ +"""Class for Pad Operator.""" + +# Copyright 2023 Physikalisch-Technische Bundesanstalt +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from mrpro.operators import LinearOperator +from mrpro.utils import change_data_shape + + +class PadOp(LinearOperator): + """Pad operator class.""" + + def __init__( + self, + dim: tuple[int, ...] = (-3, -2, -1), + orig_shape: tuple[int, ...] | None = None, + padded_shape: tuple[int, ...] | None = None, + ) -> None: + """Pad Operator class. + + The operator carries out zero-padding if the padded_shape is larger than orig_shape and cropping if the + padded_shape is smaller. + + Parameters + ---------- + dim, optional + dim along which padding should be applied, by default last three dimensions (-1, -2, -3) + orig_shape, optional + shape of original data along dim + padded_shape, optional + shape of padded data along dim + """ + self.dim: tuple[int, ...] = dim + self.orig_shape: tuple[int, ...] | None = orig_shape + self.padded_shape: tuple[int, ...] | None = padded_shape + + @staticmethod + def _pad_data(x: torch.Tensor, dim: tuple[int, ...], padded_shape: tuple[int, ...] | None) -> torch.Tensor: + """Pad or crop data. + + Parameters + ---------- + x + original data + dim + dim along which padding should be applied + padded_shape + shape of padded data + + Returns + ------- + data with shape padded_shape + """ + # Adapt image size + if padded_shape is not None: + s = list(x.shape) + for idx, idim in enumerate(dim): + s[idim] = padded_shape[idx] + x = change_data_shape(x, tuple(s)) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Pad or crop data. + + Parameters + ---------- + x + original data with shape orig_shape + + Returns + ------- + data with shape padded_shape + """ + return self._pad_data(x, self.dim, self.padded_shape) + + def adjoint(self, x: torch.Tensor) -> torch.Tensor: + """Crop or pad data. + + Parameters + ---------- + x + original data with shape padded_shape + + Returns + ------- + data with shape orig_shape + """ + return self._pad_data(x, self.dim, self.orig_shape) diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index 271d93c09..041dc5de1 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -2,6 +2,8 @@ from mrpro.operators._LinearOperator import LinearOperator from mrpro.operators._NonLinearOperator import NonLinearOperator from mrpro.operators._SensitivityOp import SensitivityOp +from mrpro.operators._PadOp import PadOp +from mrpro.operators._FastFourierOp import FastFourierOp from mrpro.operators._FourierOp import FourierOp from mrpro.operators.models._WASABI import WASABI from mrpro.operators.models._WASABITI import WASABITI diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index deaab59b9..eaf29570a 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -1,3 +1,4 @@ from mrpro.utils._smap import smap from mrpro.utils._remove_repeat import remove_repeat from mrpro.utils._rgetattr import rgetattr +from mrpro.utils._change_data_shape import change_data_shape diff --git a/src/mrpro/utils/_change_data_shape.py b/src/mrpro/utils/_change_data_shape.py new file mode 100644 index 000000000..85d7dbe23 --- /dev/null +++ b/src/mrpro/utils/_change_data_shape.py @@ -0,0 +1,49 @@ +"""Wrapper for FFT and IFFT.""" + +# Copyright 2023 Physikalisch-Technische Bundesanstalt +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import torch.nn.functional as F + + +def change_data_shape(dat: torch.Tensor, dat_shape_new: tuple[int, ...]) -> torch.Tensor: + """Change shape of data by cropping or zero-padding. + + Parameters + ---------- + dat + data + dat_shape_new + desired shape of data + + Returns + ------- + data with shape dat_shape_new + """ + s = list(dat.shape) + # Padding + npad = [0] * (2 * len(s)) + + for idx in range(len(s)): + if s[idx] != dat_shape_new[idx]: + dim_diff = dat_shape_new[idx] - s[idx] + # This is needed to ensure that padding and cropping leads to the same asymetry for odd shape differences + npad[2 * idx] = np.sign(dim_diff) * (np.abs(dim_diff) // 2) + npad[2 * idx + 1] = dat_shape_new[idx] - (s[idx] + npad[2 * idx]) + + # Pad (positive npad) or crop (negative npad) + # npad has to be reversed because pad expects it in reversed order + if not torch.all(torch.tensor(npad) == 0): + dat = F.pad(dat, npad[::-1]) + return dat diff --git a/src/mrpro/utils/fft.py b/src/mrpro/utils/fft.py deleted file mode 100644 index 50fd216e1..000000000 --- a/src/mrpro/utils/fft.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Wrapper for FFT and IFFT.""" - -# Copyright 2023 Physikalisch-Technische Bundesanstalt -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - - -def kspace_to_image(kdat: torch.Tensor, dim: tuple[int, ...] = (-1, -2, -3)) -> torch.Tensor: - """IFFT from k-space to image space. - - Parameters - ---------- - kdat - k-space data on Cartesian grid - dim, optional - dim along which iFFT is applied, by default last three dimensions (-1, -2, -3) - - Returns - ------- - iFFT of kdat - """ - return torch.fft.fftshift(torch.fft.ifftn(torch.fft.ifftshift(kdat, dim=dim), dim=dim, norm='ortho'), dim=dim) - - -def image_to_kspace(idat: torch.Tensor, dim: tuple[int, ...] = (-1, -2, -3)) -> torch.Tensor: - """FFT from image space to k-space. - - Parameters - ---------- - idat - image data on Cartesian grid - dim, optional - dim along which FFT is applied, by default last three dimensions (-1, -2, -3) - - Returns - ------- - FFT of idat - """ - return torch.fft.ifftshift(torch.fft.fftn(torch.fft.fftshift(idat, dim=dim), dim=dim, norm='ortho'), dim=dim) diff --git a/tests/algorithms/test_remove_readout_os.py b/tests/algorithms/test_remove_readout_os.py index e06c7ae58..f7305b5df 100644 --- a/tests/algorithms/test_remove_readout_os.py +++ b/tests/algorithms/test_remove_readout_os.py @@ -19,7 +19,7 @@ from mrpro.data import KData from mrpro.data import KTrajectory from mrpro.data import SpatialDimension -from mrpro.utils.fft import kspace_to_image +from mrpro.operators import FastFourierOp from tests import RandomGenerator from tests.conftest import random_kheader from tests.helper import rel_image_diff @@ -74,7 +74,8 @@ def test_remove_readout_os(monkeypatch, random_kheader): kdata = remove_readout_os(kdata) # Reconstruct image from k-space data of one coil and compare to phantom image - idat_rec = kspace_to_image(kdata.data[:, 0, ...], dim=(-1, -2)) + FFOp = FastFourierOp(dim=(-1, -2)) + idat_rec = FFOp.adjoint(kdata.data[:, 0, ...]) # Due to discretisation artifacts the reconstructed image will be different to the reference image. Using standard # testing functions such as numpy.testing.assert_almost_equal fails because there are few voxels with high diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index fe2643ee4..0974be919 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -18,7 +18,7 @@ from mrpro.data import KData from mrpro.data import KTrajectory from mrpro.data.traj_calculators._KTrajectoryCalculator import DummyTrajectory -from mrpro.utils.fft import kspace_to_image +from mrpro.operators import FastFourierOp from tests.data import IsmrmrdRawTestData from tests.helper import rel_image_diff from tests.phantoms.test_ellipse_phantom import ph_ellipse @@ -66,7 +66,8 @@ def test_KData_from_file_diff_nky_for_rep(ismrmrd_cart_invalid_reps): def test_KData_kspace(ismrmrd_cart): """Read in data and verify k-space by comparing reconstructed image.""" k = KData.from_file(ismrmrd_cart.filename, DummyTrajectory()) - irec = kspace_to_image(k.data, dim=(-1, -2)) + FFOp = FastFourierOp(dim=(-1, -2)) + irec = FFOp.adjoint(k.data) # Due to discretisation artifacts the reconstructed image will be different to the reference image. Using standard # testing functions such as numpy.testing.assert_almost_equal fails because there are few voxels with high diff --git a/tests/utils/test_fft.py b/tests/operators/test_fast_fourier_op.py similarity index 53% rename from tests/utils/test_fft.py rename to tests/operators/test_fast_fourier_op.py index 36c8dc210..2c2d1e020 100644 --- a/tests/utils/test_fft.py +++ b/tests/operators/test_fast_fourier_op.py @@ -1,4 +1,4 @@ -"""Tests for image space - k-space transformations.""" +"""Tests for Fast Fourier Operator class.""" # Copyright 2023 Physikalisch-Technische Bundesanstalt # @@ -16,13 +16,12 @@ import pytest import torch -from mrpro.utils.fft import image_to_kspace -from mrpro.utils.fft import kspace_to_image +from mrpro.operators import FastFourierOp @pytest.mark.parametrize('npoints, a', [(100, 20), (300, 20)]) -def test_kspace_to_image(npoints, a): - """Test k-space to image transformation using a Gaussian.""" +def test_fast_fourier_op_forward(npoints, a): + """Test Fast Fourier Op transformation using a Gaussian.""" # Utilize that a Fourier transform of a Gaussian function is given by # F(exp(-x^2/a)) = sqrt(pi*a)exp(-a*pi^2k^2) @@ -36,21 +35,34 @@ def test_kspace_to_image(npoints, a): igauss = torch.exp(-(x**2) / a).to(torch.complex64) kgauss = np.sqrt(torch.pi * a) * torch.exp(-a * torch.pi**2 * k**2).to(torch.complex64) - # Transform k-space to image - kgauss_fft = kspace_to_image(kgauss, dim=(0,)) + # Transform image to k-space + FFOp = FastFourierOp(dim=(0,)) + igauss_fwd = FFOp.forward(igauss) # Scaling to "undo" fft scaling - kgauss_fft *= 2 / np.sqrt(npoints) - torch.testing.assert_close(kgauss_fft, igauss) + igauss_fwd *= np.sqrt(npoints) / 2 + torch.testing.assert_close(igauss_fwd, kgauss) -def test_image_to_kspace_as_inverse(): - """Test if image_to_kspace is the inverse of kspace_to_image.""" +@pytest.mark.parametrize( + 'encoding_shape, recon_shape', + [ + ((101, 201, 50), (13, 221, 64)), + ((100, 200, 50), (14, 220, 64)), + ((101, 201, 50), (14, 220, 64)), + ((100, 200, 50), (13, 221, 64)), + ], +) +def test_fast_fourier_op_adjoint(encoding_shape, recon_shape): + """Test adjointness of Fast Fourier Op.""" - # Create random 3D data set - npoints = [200, 100, 50] - idat = torch.randn(*npoints, dtype=torch.complex64) + # Create test data + x = torch.randn(recon_shape, dtype=torch.complex64) + y = torch.randn(encoding_shape, dtype=torch.complex64) - # Transform to k-space and back along all three dimensions - idat_transform = image_to_kspace(kspace_to_image(idat)) - torch.testing.assert_close(idat, idat_transform) + # Create operator and apply + FFOp = FastFourierOp(encoding_shape=encoding_shape, recon_shape=recon_shape) + Ax = FFOp.forward(x) + AHy = FFOp.adjoint(y) + + assert torch.isclose(torch.vdot(Ax.flatten(), y.flatten()), torch.vdot(x.flatten(), AHy.flatten()), rtol=1e-3) diff --git a/tests/operators/test_pad_op.py b/tests/operators/test_pad_op.py new file mode 100644 index 000000000..c70928835 --- /dev/null +++ b/tests/operators/test_pad_op.py @@ -0,0 +1,56 @@ +"""Tests for Pad Operator class.""" + +# Copyright 2023 Physikalisch-Technische Bundesanstalt +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import torch + +from mrpro.operators import PadOp + + +def test_pad_op_content(): + """Test correct padding.""" + dshape_orig = (2, 100, 3, 200, 50, 2) + dshape_new = (2, 80, 3, 100, 240, 2) + dorig = torch.randn(*dshape_orig, dtype=torch.complex64) + pad_dim = (-5, -3, -2) + POp = PadOp( + dim=pad_dim, + orig_shape=tuple([dshape_orig[d] for d in pad_dim]), + padded_shape=tuple([dshape_new[d] for d in pad_dim]), + ) + dnew = POp.forward(dorig) + + # Compare overlapping region + torch.testing.assert_close(dorig[:, 10:90, :, 50:150, :, :], dnew[:, :, :, :, 95:145, :]) + + +@pytest.mark.parametrize( + 'u_shape, v_shape', + [ + ((101, 201, 50), (13, 221, 64)), + ((100, 200, 50), (14, 220, 64)), + ((101, 201, 50), (14, 220, 64)), + ((100, 200, 50), (13, 221, 64)), + ], +) +def test_pad_op_ajoint(u_shape, v_shape): + """Test adjointness of pad operator.""" + u = torch.randn(u_shape, dtype=torch.complex64) + v = torch.randn(v_shape, dtype=torch.complex64) + POp = PadOp(orig_shape=u_shape, padded_shape=v_shape) + Au = POp.forward(u) + AHv = POp.adjoint(v) + + assert torch.isclose(torch.vdot(Au.flatten(), v.flatten()), torch.vdot(u.flatten(), AHv.flatten()), rtol=1e-3) diff --git a/tests/phantoms/test_ellipse_phantom.py b/tests/phantoms/test_ellipse_phantom.py index c7d710a66..f1a839c22 100644 --- a/tests/phantoms/test_ellipse_phantom.py +++ b/tests/phantoms/test_ellipse_phantom.py @@ -16,7 +16,7 @@ import torch from mrpro.data import SpatialDimension -from mrpro.utils.fft import kspace_to_image +from mrpro.operators import FastFourierOp from tests.helper import rel_image_diff from tests.phantoms._EllipsePhantomTestData import EllipsePhantomTestData @@ -55,7 +55,8 @@ def test_kspace_image_match(ph_ellipse): im_dim = SpatialDimension(z=1, y=ph_ellipse.ny, x=ph_ellipse.nx) im = ph_ellipse.phantom.image_space(im_dim) kdat = ph_ellipse.phantom.kspace(ph_ellipse.ky, ph_ellipse.kx) - irec = kspace_to_image(kdat, dim=(-1, -2)) + FFOp = FastFourierOp(dim=(-1, -2)) + irec = FFOp.adjoint(kdat) # Due to discretisation artifacts the reconstructed image will be different to the reference image. Using standard # testing functions such as numpy.testing.assert_almost_equal fails because there are few voxels with high # differences along the edges of the elliptic objects. diff --git a/tests/utils/test_change_data_shape.py b/tests/utils/test_change_data_shape.py new file mode 100644 index 000000000..e19493317 --- /dev/null +++ b/tests/utils/test_change_data_shape.py @@ -0,0 +1,30 @@ +"""Tests for image space - k-space transformations.""" + +# Copyright 2023 Physikalisch-Technische Bundesanstalt +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import torch + +from mrpro.utils._change_data_shape import change_data_shape + + +def test_change_data_shape_content(): + """Test changing data size by cropping and padding.""" + dshape_orig = (100, 200, 50) + dshape_new = (80, 100, 240) + dorig = torch.randn(*dshape_orig, dtype=torch.complex64) + dnew = change_data_shape(dorig, dshape_new) + + # Compare overlapping region + torch.testing.assert_close(dorig[10:90, 50:150, :], dnew[:, :, 95:145]) From 42a8f0e12cc1e56802a80f3817503916bfeddf4d Mon Sep 17 00:00:00 2001 From: ckolbPTB Date: Thu, 25 Jan 2024 10:27:14 +0100 Subject: [PATCH 2/5] PadOp without None --- src/mrpro/operators/_FastFourierOp.py | 8 ++++++- src/mrpro/operators/_PadOp.py | 34 +++++++++++++++------------ tests/operators/test_pad_op.py | 2 +- 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/src/mrpro/operators/_FastFourierOp.py b/src/mrpro/operators/_FastFourierOp.py index 092e7eb82..e7a506341 100644 --- a/src/mrpro/operators/_FastFourierOp.py +++ b/src/mrpro/operators/_FastFourierOp.py @@ -44,8 +44,14 @@ def __init__( recon_shape, optional shape of reconstructed data """ + super().__init__() self._dim: tuple[int, ...] = dim - self._pad_op: PadOp = PadOp(dim=dim, orig_shape=encoding_shape, padded_shape=recon_shape) + self._pad_op: PadOp + if encoding_shape is not None and recon_shape is not None: + self._pad_op = PadOp(dim=dim, orig_shape=recon_shape, padded_shape=encoding_shape) + else: + # No padding + self._pad_op = PadOp(dim=(), orig_shape=(), padded_shape=()) def forward(self, x: torch.Tensor) -> torch.Tensor: """FFT from image space to k-space. diff --git a/src/mrpro/operators/_PadOp.py b/src/mrpro/operators/_PadOp.py index e8b4db1f2..5d5efb6f5 100644 --- a/src/mrpro/operators/_PadOp.py +++ b/src/mrpro/operators/_PadOp.py @@ -23,9 +23,9 @@ class PadOp(LinearOperator): def __init__( self, - dim: tuple[int, ...] = (-3, -2, -1), - orig_shape: tuple[int, ...] | None = None, - padded_shape: tuple[int, ...] | None = None, + dim: tuple[int, ...], + orig_shape: tuple[int, ...], + padded_shape: tuple[int, ...], ) -> None: """Pad Operator class. @@ -34,19 +34,23 @@ def __init__( Parameters ---------- - dim, optional - dim along which padding should be applied, by default last three dimensions (-1, -2, -3) - orig_shape, optional - shape of original data along dim - padded_shape, optional - shape of padded data along dim + dim + dim along which padding should be applied + orig_shape + shape of original data along dim, same length as dim + padded_shape + shape of padded data along dim, same length as dim """ + if len(dim) != len(orig_shape) or len(dim) != len(padded_shape): + raise ValueError('Dim, orig_shape and padded_shape have to be of same length') + + super().__init__() self.dim: tuple[int, ...] = dim - self.orig_shape: tuple[int, ...] | None = orig_shape - self.padded_shape: tuple[int, ...] | None = padded_shape + self.orig_shape: tuple[int, ...] = orig_shape + self.padded_shape: tuple[int, ...] = padded_shape @staticmethod - def _pad_data(x: torch.Tensor, dim: tuple[int, ...], padded_shape: tuple[int, ...] | None) -> torch.Tensor: + def _pad_data(x: torch.Tensor, dim: tuple[int, ...], padded_shape: tuple[int, ...]) -> torch.Tensor: """Pad or crop data. Parameters @@ -63,7 +67,7 @@ def _pad_data(x: torch.Tensor, dim: tuple[int, ...], padded_shape: tuple[int, .. data with shape padded_shape """ # Adapt image size - if padded_shape is not None: + if len(dim) > 0: s = list(x.shape) for idx, idim in enumerate(dim): s[idim] = padded_shape[idx] @@ -76,7 +80,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Parameters ---------- x - original data with shape orig_shape + data with shape orig_shape Returns ------- @@ -90,7 +94,7 @@ def adjoint(self, x: torch.Tensor) -> torch.Tensor: Parameters ---------- x - original data with shape padded_shape + data with shape padded_shape Returns ------- diff --git a/tests/operators/test_pad_op.py b/tests/operators/test_pad_op.py index c70928835..7f40db45e 100644 --- a/tests/operators/test_pad_op.py +++ b/tests/operators/test_pad_op.py @@ -49,7 +49,7 @@ def test_pad_op_ajoint(u_shape, v_shape): """Test adjointness of pad operator.""" u = torch.randn(u_shape, dtype=torch.complex64) v = torch.randn(v_shape, dtype=torch.complex64) - POp = PadOp(orig_shape=u_shape, padded_shape=v_shape) + POp = PadOp(dim=(-3, -2, -1), orig_shape=u_shape, padded_shape=v_shape) Au = POp.forward(u) AHv = POp.adjoint(v) From 11eacb9afd0b2b75b61e0665b757b73bd5864ef3 Mon Sep 17 00:00:00 2001 From: ckolbPTB Date: Thu, 25 Jan 2024 10:34:57 +0100 Subject: [PATCH 3/5] ensure recon and encoding shape is provided same as for FourierOp --- src/mrpro/operators/_FastFourierOp.py | 2 +- src/mrpro/operators/_FourierOp.py | 2 +- tests/operators/test_fast_fourier_op.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mrpro/operators/_FastFourierOp.py b/src/mrpro/operators/_FastFourierOp.py index e7a506341..28eca44d5 100644 --- a/src/mrpro/operators/_FastFourierOp.py +++ b/src/mrpro/operators/_FastFourierOp.py @@ -24,8 +24,8 @@ class FastFourierOp(LinearOperator): def __init__( self, dim: tuple[int, ...] = (-3, -2, -1), - encoding_shape: tuple[int, ...] | None = None, recon_shape: tuple[int, ...] | None = None, + encoding_shape: tuple[int, ...] | None = None, ) -> None: """Fast Fourier Operator class. diff --git a/src/mrpro/operators/_FourierOp.py b/src/mrpro/operators/_FourierOp.py index 530874713..144bc3368 100644 --- a/src/mrpro/operators/_FourierOp.py +++ b/src/mrpro/operators/_FourierOp.py @@ -153,7 +153,7 @@ def __init__( self._nufft_im_size = nufft_im_size self._fast_fourier_op = FastFourierOp( - dim=self._fft_dims, encoding_shape=self._fft_encoding_shape, recon_shape=self._fft_recon_shape + dim=self._fft_dims, recon_shape=self._fft_recon_shape, encoding_shape=self._fft_encoding_shape ) @staticmethod diff --git a/tests/operators/test_fast_fourier_op.py b/tests/operators/test_fast_fourier_op.py index 2c2d1e020..d2db6fd69 100644 --- a/tests/operators/test_fast_fourier_op.py +++ b/tests/operators/test_fast_fourier_op.py @@ -61,7 +61,7 @@ def test_fast_fourier_op_adjoint(encoding_shape, recon_shape): y = torch.randn(encoding_shape, dtype=torch.complex64) # Create operator and apply - FFOp = FastFourierOp(encoding_shape=encoding_shape, recon_shape=recon_shape) + FFOp = FastFourierOp(recon_shape=recon_shape, encoding_shape=encoding_shape) Ax = FFOp.forward(x) AHy = FFOp.adjoint(y) From aa4c88a71b6f650ed70f2d265a7f83e0b948d8e8 Mon Sep 17 00:00:00 2001 From: ckolbPTB Date: Fri, 26 Jan 2024 15:02:28 +0100 Subject: [PATCH 4/5] review comments --- src/mrpro/operators/_FastFourierOp.py | 8 +- .../operators/{_PadOp.py => _ZeroPadOp.py} | 39 ++------ src/mrpro/operators/__init__.py | 2 +- src/mrpro/utils/__init__.py | 2 +- src/mrpro/utils/_change_data_shape.py | 49 ---------- src/mrpro/utils/_zero_pad_or_crop.py | 89 +++++++++++++++++++ tests/operators/test_fast_fourier_op.py | 2 +- .../{test_pad_op.py => test_zero_pad_op.py} | 12 +-- ...data_shape.py => test_zero_pad_or_crop.py} | 16 ++-- 9 files changed, 117 insertions(+), 102 deletions(-) rename src/mrpro/operators/{_PadOp.py => _ZeroPadOp.py} (69%) delete mode 100644 src/mrpro/utils/_change_data_shape.py create mode 100644 src/mrpro/utils/_zero_pad_or_crop.py rename tests/operators/{test_pad_op.py => test_zero_pad_op.py} (86%) rename tests/utils/{test_change_data_shape.py => test_zero_pad_or_crop.py} (66%) diff --git a/src/mrpro/operators/_FastFourierOp.py b/src/mrpro/operators/_FastFourierOp.py index 28eca44d5..bbb2ae39f 100644 --- a/src/mrpro/operators/_FastFourierOp.py +++ b/src/mrpro/operators/_FastFourierOp.py @@ -15,7 +15,7 @@ import torch from mrpro.operators import LinearOperator -from mrpro.operators import PadOp +from mrpro.operators import ZeroPadOp class FastFourierOp(LinearOperator): @@ -46,12 +46,12 @@ def __init__( """ super().__init__() self._dim: tuple[int, ...] = dim - self._pad_op: PadOp + self._pad_op: ZeroPadOp if encoding_shape is not None and recon_shape is not None: - self._pad_op = PadOp(dim=dim, orig_shape=recon_shape, padded_shape=encoding_shape) + self._pad_op = ZeroPadOp(dim=dim, orig_shape=recon_shape, padded_shape=encoding_shape) else: # No padding - self._pad_op = PadOp(dim=(), orig_shape=(), padded_shape=()) + self._pad_op = ZeroPadOp(dim=(), orig_shape=(), padded_shape=()) def forward(self, x: torch.Tensor) -> torch.Tensor: """FFT from image space to k-space. diff --git a/src/mrpro/operators/_PadOp.py b/src/mrpro/operators/_ZeroPadOp.py similarity index 69% rename from src/mrpro/operators/_PadOp.py rename to src/mrpro/operators/_ZeroPadOp.py index 5d5efb6f5..0fc3662a3 100644 --- a/src/mrpro/operators/_PadOp.py +++ b/src/mrpro/operators/_ZeroPadOp.py @@ -1,4 +1,4 @@ -"""Class for Pad Operator.""" +"""Class for Zero Pad Operator.""" # Copyright 2023 Physikalisch-Technische Bundesanstalt # @@ -15,11 +15,11 @@ import torch from mrpro.operators import LinearOperator -from mrpro.utils import change_data_shape +from mrpro.utils import zero_pad_or_crop -class PadOp(LinearOperator): - """Pad operator class.""" +class ZeroPadOp(LinearOperator): + """Zero Pad operator class.""" def __init__( self, @@ -27,7 +27,7 @@ def __init__( orig_shape: tuple[int, ...], padded_shape: tuple[int, ...], ) -> None: - """Pad Operator class. + """Zero Pad Operator class. The operator carries out zero-padding if the padded_shape is larger than orig_shape and cropping if the padded_shape is smaller. @@ -49,31 +49,6 @@ def __init__( self.orig_shape: tuple[int, ...] = orig_shape self.padded_shape: tuple[int, ...] = padded_shape - @staticmethod - def _pad_data(x: torch.Tensor, dim: tuple[int, ...], padded_shape: tuple[int, ...]) -> torch.Tensor: - """Pad or crop data. - - Parameters - ---------- - x - original data - dim - dim along which padding should be applied - padded_shape - shape of padded data - - Returns - ------- - data with shape padded_shape - """ - # Adapt image size - if len(dim) > 0: - s = list(x.shape) - for idx, idim in enumerate(dim): - s[idim] = padded_shape[idx] - x = change_data_shape(x, tuple(s)) - return x - def forward(self, x: torch.Tensor) -> torch.Tensor: """Pad or crop data. @@ -86,7 +61,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ------- data with shape padded_shape """ - return self._pad_data(x, self.dim, self.padded_shape) + return zero_pad_or_crop(x, self.padded_shape, self.dim) def adjoint(self, x: torch.Tensor) -> torch.Tensor: """Crop or pad data. @@ -100,4 +75,4 @@ def adjoint(self, x: torch.Tensor) -> torch.Tensor: ------- data with shape orig_shape """ - return self._pad_data(x, self.dim, self.orig_shape) + return zero_pad_or_crop(x, self.orig_shape, self.dim) diff --git a/src/mrpro/operators/__init__.py b/src/mrpro/operators/__init__.py index 041dc5de1..ec6b4c5e0 100644 --- a/src/mrpro/operators/__init__.py +++ b/src/mrpro/operators/__init__.py @@ -2,7 +2,7 @@ from mrpro.operators._LinearOperator import LinearOperator from mrpro.operators._NonLinearOperator import NonLinearOperator from mrpro.operators._SensitivityOp import SensitivityOp -from mrpro.operators._PadOp import PadOp +from mrpro.operators._ZeroPadOp import ZeroPadOp from mrpro.operators._FastFourierOp import FastFourierOp from mrpro.operators._FourierOp import FourierOp from mrpro.operators.models._WASABI import WASABI diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index eaf29570a..63df948e2 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -1,4 +1,4 @@ from mrpro.utils._smap import smap from mrpro.utils._remove_repeat import remove_repeat from mrpro.utils._rgetattr import rgetattr -from mrpro.utils._change_data_shape import change_data_shape +from mrpro.utils._zero_pad_or_crop import zero_pad_or_crop diff --git a/src/mrpro/utils/_change_data_shape.py b/src/mrpro/utils/_change_data_shape.py deleted file mode 100644 index 85d7dbe23..000000000 --- a/src/mrpro/utils/_change_data_shape.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Wrapper for FFT and IFFT.""" - -# Copyright 2023 Physikalisch-Technische Bundesanstalt -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import torch -import torch.nn.functional as F - - -def change_data_shape(dat: torch.Tensor, dat_shape_new: tuple[int, ...]) -> torch.Tensor: - """Change shape of data by cropping or zero-padding. - - Parameters - ---------- - dat - data - dat_shape_new - desired shape of data - - Returns - ------- - data with shape dat_shape_new - """ - s = list(dat.shape) - # Padding - npad = [0] * (2 * len(s)) - - for idx in range(len(s)): - if s[idx] != dat_shape_new[idx]: - dim_diff = dat_shape_new[idx] - s[idx] - # This is needed to ensure that padding and cropping leads to the same asymetry for odd shape differences - npad[2 * idx] = np.sign(dim_diff) * (np.abs(dim_diff) // 2) - npad[2 * idx + 1] = dat_shape_new[idx] - (s[idx] + npad[2 * idx]) - - # Pad (positive npad) or crop (negative npad) - # npad has to be reversed because pad expects it in reversed order - if not torch.all(torch.tensor(npad) == 0): - dat = F.pad(dat, npad[::-1]) - return dat diff --git a/src/mrpro/utils/_zero_pad_or_crop.py b/src/mrpro/utils/_zero_pad_or_crop.py new file mode 100644 index 000000000..d33a96051 --- /dev/null +++ b/src/mrpro/utils/_zero_pad_or_crop.py @@ -0,0 +1,89 @@ +"""Zero pad and crop data tensor.""" + +# Copyright 2023 Physikalisch-Technische Bundesanstalt +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn.functional as F + + +def normalize_index(ndim: int, index: int): + """Normalize possibly negative indices. + + Parameters + ---------- + ndim + number of dimensions + index + index to normalize. negative indices count from the end. + + Raises + ------ + IndexError + if index is outside [-ndim,ndim) + """ + if 0 < index < ndim: + return index + elif -ndim <= index < 0: + return ndim + index + else: + raise IndexError(f'Invalid index {index} for {ndim} data dimensions') + + +def zero_pad_or_crop( + data: torch.Tensor, new_shape: tuple[int, ...] | torch.Size, dim: None | tuple[int, ...] = None +) -> torch.Tensor: + """Change shape of data by cropping or zero-padding. + + Parameters + ---------- + data + data + new_shape + desired shape of data + dim: + dimensions the new_shape corresponds to. None (default) is interpreted as last len(new_shape) dimensions. + + Returns + ------- + data zero padded or cropped to shape + """ + + if len(new_shape) > data.ndim: + raise ValueError('length of new shape should not exceed dimensions of data') + + if dim is None: # Use last dimensions + new_shape = data.shape[: -len(new_shape)] + new_shape + else: + if len(new_shape) != len(dim): + raise ValueError('length of shape should match length of dim') + dim = tuple(normalize_index(data.ndim, idx) for idx in dim) # raises if any not in [-data.ndim,data.ndim) + if len(dim) != len(set(dim)): # this is why we normalize + raise ValueError('repeated values are not allowed in dims') + new_shape_full = torch.tensor(data.shape) + for i, d in enumerate(dim): + new_shape_full[d] = new_shape[i] + + npad = [] + for old, new in zip(torch.tensor(data.shape), new_shape_full): + diff = (new - old).numpy() # __trunc__ method not available for tensors + after = math.trunc(diff / 2) + before = diff - after + npad.append(before) + npad.append(after) + + if any(npad): + # F.pad expects paddings in reversed order + data = F.pad(data, npad[::-1]) + return data diff --git a/tests/operators/test_fast_fourier_op.py b/tests/operators/test_fast_fourier_op.py index d2db6fd69..ebfbdd5cc 100644 --- a/tests/operators/test_fast_fourier_op.py +++ b/tests/operators/test_fast_fourier_op.py @@ -21,7 +21,7 @@ @pytest.mark.parametrize('npoints, a', [(100, 20), (300, 20)]) def test_fast_fourier_op_forward(npoints, a): - """Test Fast Fourier Op transformation using a Gaussian.""" + """Test Fast Fourier Op transformation using a Gaussian.""" # Utilize that a Fourier transform of a Gaussian function is given by # F(exp(-x^2/a)) = sqrt(pi*a)exp(-a*pi^2k^2) diff --git a/tests/operators/test_pad_op.py b/tests/operators/test_zero_pad_op.py similarity index 86% rename from tests/operators/test_pad_op.py rename to tests/operators/test_zero_pad_op.py index 7f40db45e..5a55fa059 100644 --- a/tests/operators/test_pad_op.py +++ b/tests/operators/test_zero_pad_op.py @@ -1,4 +1,4 @@ -"""Tests for Pad Operator class.""" +"""Tests for Zero Pad Operator class.""" # Copyright 2023 Physikalisch-Technische Bundesanstalt # @@ -16,16 +16,16 @@ import pytest import torch -from mrpro.operators import PadOp +from mrpro.operators import ZeroPadOp -def test_pad_op_content(): +def test_zero_pad_op_content(): """Test correct padding.""" dshape_orig = (2, 100, 3, 200, 50, 2) dshape_new = (2, 80, 3, 100, 240, 2) dorig = torch.randn(*dshape_orig, dtype=torch.complex64) pad_dim = (-5, -3, -2) - POp = PadOp( + POp = ZeroPadOp( dim=pad_dim, orig_shape=tuple([dshape_orig[d] for d in pad_dim]), padded_shape=tuple([dshape_new[d] for d in pad_dim]), @@ -45,11 +45,11 @@ def test_pad_op_content(): ((100, 200, 50), (13, 221, 64)), ], ) -def test_pad_op_ajoint(u_shape, v_shape): +def test_zero_pad_op_ajoint(u_shape, v_shape): """Test adjointness of pad operator.""" u = torch.randn(u_shape, dtype=torch.complex64) v = torch.randn(v_shape, dtype=torch.complex64) - POp = PadOp(dim=(-3, -2, -1), orig_shape=u_shape, padded_shape=v_shape) + POp = ZeroPadOp(dim=(-3, -2, -1), orig_shape=u_shape, padded_shape=v_shape) Au = POp.forward(u) AHv = POp.adjoint(v) diff --git a/tests/utils/test_change_data_shape.py b/tests/utils/test_zero_pad_or_crop.py similarity index 66% rename from tests/utils/test_change_data_shape.py rename to tests/utils/test_zero_pad_or_crop.py index e19493317..790a52983 100644 --- a/tests/utils/test_change_data_shape.py +++ b/tests/utils/test_zero_pad_or_crop.py @@ -1,4 +1,4 @@ -"""Tests for image space - k-space transformations.""" +"""Tests for zero padding and cropping of data tensors.""" # Copyright 2023 Physikalisch-Technische Bundesanstalt # @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np -import pytest import torch -from mrpro.utils._change_data_shape import change_data_shape +from mrpro.utils._zero_pad_or_crop import zero_pad_or_crop +from tests import RandomGenerator -def test_change_data_shape_content(): - """Test changing data size by cropping and padding.""" +def test_zero_pad_or_crop_content(): + """Test changing data by cropping and padding.""" + generator = RandomGenerator(seed=0) dshape_orig = (100, 200, 50) dshape_new = (80, 100, 240) - dorig = torch.randn(*dshape_orig, dtype=torch.complex64) - dnew = change_data_shape(dorig, dshape_new) + dorig = generator.complex64_tensor(dshape_orig) + dnew = zero_pad_or_crop(dorig, dshape_new, dim=(-3, -2, -1)) # Compare overlapping region torch.testing.assert_close(dorig[10:90, 50:150, :], dnew[:, :, 95:145]) From 77d49bb69778711a97152a94cec9446b29a95071 Mon Sep 17 00:00:00 2001 From: ckolbPTB Date: Fri, 26 Jan 2024 17:20:02 +0100 Subject: [PATCH 5/5] review comments --- tests/operators/test_fast_fourier_op.py | 6 ++++-- tests/operators/test_zero_pad_op.py | 9 ++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/operators/test_fast_fourier_op.py b/tests/operators/test_fast_fourier_op.py index ebfbdd5cc..72c291ed5 100644 --- a/tests/operators/test_fast_fourier_op.py +++ b/tests/operators/test_fast_fourier_op.py @@ -17,6 +17,7 @@ import torch from mrpro.operators import FastFourierOp +from tests import RandomGenerator @pytest.mark.parametrize('npoints, a', [(100, 20), (300, 20)]) @@ -57,8 +58,9 @@ def test_fast_fourier_op_adjoint(encoding_shape, recon_shape): """Test adjointness of Fast Fourier Op.""" # Create test data - x = torch.randn(recon_shape, dtype=torch.complex64) - y = torch.randn(encoding_shape, dtype=torch.complex64) + generator = RandomGenerator(seed=0) + x = generator.complex64_tensor(recon_shape) + y = generator.complex64_tensor(encoding_shape) # Create operator and apply FFOp = FastFourierOp(recon_shape=recon_shape, encoding_shape=encoding_shape) diff --git a/tests/operators/test_zero_pad_op.py b/tests/operators/test_zero_pad_op.py index 5a55fa059..81d0c17e3 100644 --- a/tests/operators/test_zero_pad_op.py +++ b/tests/operators/test_zero_pad_op.py @@ -17,13 +17,15 @@ import torch from mrpro.operators import ZeroPadOp +from tests import RandomGenerator def test_zero_pad_op_content(): """Test correct padding.""" dshape_orig = (2, 100, 3, 200, 50, 2) dshape_new = (2, 80, 3, 100, 240, 2) - dorig = torch.randn(*dshape_orig, dtype=torch.complex64) + generator = RandomGenerator(seed=0) + dorig = generator.complex64_tensor(dshape_orig) pad_dim = (-5, -3, -2) POp = ZeroPadOp( dim=pad_dim, @@ -47,8 +49,9 @@ def test_zero_pad_op_content(): ) def test_zero_pad_op_ajoint(u_shape, v_shape): """Test adjointness of pad operator.""" - u = torch.randn(u_shape, dtype=torch.complex64) - v = torch.randn(v_shape, dtype=torch.complex64) + generator = RandomGenerator(seed=0) + u = generator.complex64_tensor(u_shape) + v = generator.complex64_tensor(v_shape) POp = ZeroPadOp(dim=(-3, -2, -1), orig_shape=u_shape, padded_shape=v_shape) Au = POp.forward(u) AHv = POp.adjoint(v)