diff --git a/pylops/signalprocessing/fft.py b/pylops/signalprocessing/fft.py index 7a3051b3..9cf2cd72 100644 --- a/pylops/signalprocessing/fft.py +++ b/pylops/signalprocessing/fft.py @@ -3,9 +3,6 @@ import logging import warnings from typing import Optional, Union -from mkl_fft import _numpy_fft as pymkl_fft -from mkl_fft._scipy_fft_backend import fftshift as mkl_fftshift, ifftshift as mkl_ifftshift - import numpy as np import numpy.typing as npt @@ -405,6 +402,9 @@ def __init__( @reshaped def _matvec(self, x: NDArray) -> NDArray: + from mkl_fft import _numpy_fft as pymkl_fft + from mkl_fft._scipy_fft_backend import fftshift as mkl_fftshift, ifftshift as mkl_ifftshift + if self.ifftshift_before: x = mkl_ifftshift(x, axes=self.axis) if not self.clinear: @@ -425,6 +425,9 @@ def _matvec(self, x: NDArray) -> NDArray: @reshaped def _rmatvec(self, x: NDArray) -> NDArray: + from mkl_fft import _numpy_fft as pymkl_fft + from mkl_fft._scipy_fft_backend import fftshift as mkl_fftshift, ifftshift as mkl_ifftshift + if self.fftshift_after: x = mkl_ifftshift(x, axes=self.axis) if self.real: diff --git a/pylops/signalprocessing/fft2d.py b/pylops/signalprocessing/fft2d.py index e1957b25..8cf39470 100644 --- a/pylops/signalprocessing/fft2d.py +++ b/pylops/signalprocessing/fft2d.py @@ -5,7 +5,6 @@ from typing import Dict, Optional, Sequence, Union from pylops.signalprocessing.fftnd import _FFTND_mklfft -from mkl_fft._scipy_fft_backend import fftshift as mkl_fftshift, ifftshift as mkl_ifftshift import numpy as np import scipy.fft @@ -266,6 +265,8 @@ def __init__( @reshaped def _matvec(self, x): + from mkl_fft._scipy_fft_backend import fftshift as mkl_fftshift, ifftshift as mkl_ifftshift + if self.ifftshift_before.any(): x = mkl_ifftshift(x, axes=self.axes[self.ifftshift_before]) if not self.clinear: @@ -286,6 +287,8 @@ def _matvec(self, x): @reshaped def _rmatvec(self, x): + from mkl_fft._scipy_fft_backend import fftshift as mkl_fftshift, ifftshift as mkl_ifftshift + if self.fftshift_after.any(): x = mkl_ifftshift(x, axes=self.axes[self.fftshift_after]) if self.real: diff --git a/pylops/signalprocessing/fftnd.py b/pylops/signalprocessing/fftnd.py index ae00b6e1..095ce672 100644 --- a/pylops/signalprocessing/fftnd.py +++ b/pylops/signalprocessing/fftnd.py @@ -238,7 +238,7 @@ def _matvec(self, x: NDArray) -> NDArray: from mkl_fft._scipy_fft_backend import fftshift as mkl_fftshift, ifftshift as mkl_iffshift if self.ifftshift_before.any(): - x = mkl_fftshift(x, axes=self.axes[self.ifftshift_before]) + x = mkl_iffshift(x, axes=self.axes[self.ifftshift_before]) if not self.clinear: x = np.real(x) if self.real: @@ -252,7 +252,7 @@ def _matvec(self, x: NDArray) -> NDArray: if self.norm is _FFTNorms.ONE_OVER_N: y *= self._scale if self.fftshift_after.any(): - y = mkl_iffshift(y, axes=self.axes[self.fftshift_after]) + y = mkl_fftshift(y, axes=self.axes[self.fftshift_after]) return y @reshaped