diff --git a/tests/operators/test_fourier_op.py b/tests/operators/test_fourier_op.py index 28fd03969..23479a0b6 100644 --- a/tests/operators/test_fourier_op.py +++ b/tests/operators/test_fourier_op.py @@ -11,6 +11,18 @@ from tests.conftest import COMMON_MR_TRAJECTORIES, create_traj +class NufftTrajektory(KTrajectory): + """Always returns non-grid trajectory type.""" + + def _traj_types( + self, + tolerance: float, + ) -> tuple[tuple[TrajType, TrajType, TrajType], tuple[TrajType, TrajType, TrajType]]: + true_types = super()._traj_types(tolerance) + modified = tuple([tuple([t & (~TrajType.ONGRID) for t in ts]) for ts in true_types]) + return modified + + def create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz): random_generator = RandomGenerator(seed=0) @@ -39,13 +51,10 @@ def test_fourier_op_fwd_adj_property( ) fourier_op = FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) - # apply forward operator - (kdata,) = fourier_op(img) - # test adjoint property; i.e. == for all u,v random_generator = RandomGenerator(seed=0) - u = random_generator.complex64_tensor(size=img.shape) - v = random_generator.complex64_tensor(size=kdata.shape) + u = random_generator.complex64_tensor(size=im_shape) + v = random_generator.complex64_tensor(size=k_shape) dotproduct_adjointness_test(fourier_op, u, v) @@ -64,19 +73,21 @@ def test_fourier_op_norm(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, typ int(trajectory.kx.max() - trajectory.kx.min() + 1), ) fourier_op = FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) - # only do few iterations to speed up the test - norm = fourier_op.operator_norm(img, dim=None, max_iterations=20) - torch.testing.assert_close(norm.squeeze(), torch.tensor(1.0), atol=0.1, rtol=0.0) + (initial_value,) = fourier_op.adjoint(*fourier_op(img)) + norm = fourier_op.operator_norm(initial_value, dim=None, max_iterations=4).squeeze() + torch.testing.assert_close(norm, torch.tensor(1.0), atol=0.1, rtol=0.0) @COMMON_MR_TRAJECTORIES -def test_fourier_op_fft_nufft(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2): +def test_fourier_op_fft_nufft_forward( + im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2 +): """Test Nufft vs FFT for Fourier operator.""" + if not any(t == 'uniform' for t in [type_kx, type_ky, type_kz]): + return # only test for uniform trajectories - # generate random images and k-space trajectories img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) - # create operator recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) encoding_matrix = SpatialDimension( int(trajectory.kz.max() - trajectory.kz.min() + 1), @@ -85,35 +96,45 @@ def test_fourier_op_fft_nufft(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky ) fourier_op = FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) - class NufftTrajektory(KTrajectory): - """Always returns non-grid trajectory type.""" - - def _traj_types( - self, - tolerance: float, - ) -> tuple[tuple[TrajType, TrajType, TrajType], tuple[TrajType, TrajType, TrajType]]: - true_types = super()._traj_types(tolerance) - modified = tuple([tuple([t & (~TrajType.ONGRID) for t in ts]) for ts in true_types]) - return modified - nufft_fourier_op = FourierOp( recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=NufftTrajektory(trajectory.kz, trajectory.ky, trajectory.kx), + nufft_oversampling=8.0, ) (result_normal,) = fourier_op(img) (result_nufft,) = nufft_fourier_op(img) - torch.testing.assert_close(result_normal, result_nufft, atol=1e-5, rtol=1e-4) + torch.testing.assert_close(result_normal, result_nufft, atol=1e-4, rtol=5e-3) + + +@COMMON_MR_TRAJECTORIES +def test_fourier_op_fft_nufft_adjoint( + im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz, type_k0, type_k1, type_k2 +): + """Test AdjointNufft vs IFFT for Fourier operator.""" + if not any(t == 'uniform' for t in [type_kx, type_ky, type_kz]): + return # only test for uniform trajectories + img, trajectory = create_data(im_shape, k_shape, nkx, nky, nkz, type_kx, type_ky, type_kz) + recon_matrix = SpatialDimension(im_shape[-3], im_shape[-2], im_shape[-1]) + encoding_matrix = SpatialDimension( + int(trajectory.kz.max() - trajectory.kz.min() + 1), + int(trajectory.ky.max() - trajectory.ky.min() + 1), + int(trajectory.kx.max() - trajectory.kx.min() + 1), + ) + fourier_op = FourierOp(recon_matrix=recon_matrix, encoding_matrix=encoding_matrix, traj=trajectory) + + nufft_fourier_op = FourierOp( + recon_matrix=recon_matrix, + encoding_matrix=encoding_matrix, + traj=NufftTrajektory(trajectory.kz, trajectory.ky, trajectory.kx), + nufft_oversampling=8.0, + ) k = RandomGenerator(0).complex64_tensor(size=k_shape) (result_normal,) = fourier_op.H(k) (result_nufft,) = nufft_fourier_op.H(k) - torch.testing.assert_close(result_normal, result_nufft, atol=1e-5, rtol=1e-4) - - (result_normal,) = fourier_op(img) - (result_nufft,) = nufft_fourier_op(img) - torch.testing.assert_close(result_normal, result_nufft, atol=1e-5, rtol=1e-4) + torch.testing.assert_close(result_normal, result_nufft, atol=3e-4, rtol=5e-3) @COMMON_MR_TRAJECTORIES