Skip to content

Commit

Permalink
minor: added ncp to fftnd
Browse files Browse the repository at this point in the history
  • Loading branch information
mrava87 committed Dec 21, 2023
1 parent d3b3b66 commit 780efa3
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions pylops/signalprocessing/fftnd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy.typing as npt

from pylops.signalprocessing._baseffts import _BaseFFTND, _FFTNorms
from pylops.utils.backend import get_sp_fft
from pylops.utils.backend import get_array_module, get_sp_fft
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray

Expand Down Expand Up @@ -56,50 +56,52 @@ def __init__(

@reshaped
def _matvec(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
if self.ifftshift_before.any():
x = np.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
x = ncp.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
if not self.clinear:
x = np.real(x)
x = ncp.real(x)
if self.real:
y = np.fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = ncp.fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
# Apply scaling to obtain a correct adjoint for this operator
y = np.swapaxes(y, -1, self.axes[-1])
y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2)
y = np.swapaxes(y, self.axes[-1], -1)
y = ncp.swapaxes(y, -1, self.axes[-1])
y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= ncp.sqrt(2)
y = ncp.swapaxes(y, self.axes[-1], -1)
else:
y = np.fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = ncp.fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.ONE_OVER_N:
y *= self._scale
y = y.astype(self.cdtype)
if self.fftshift_after.any():
y = np.fft.fftshift(y, axes=self.axes[self.fftshift_after])
y = ncp.fft.fftshift(y, axes=self.axes[self.fftshift_after])
return y

@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
if self.fftshift_after.any():
x = np.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
x = ncp.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
if self.real:
# Apply scaling to obtain a correct adjoint for this operator
x = x.copy()
x = np.swapaxes(x, -1, self.axes[-1])
x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2)
x = np.swapaxes(x, self.axes[-1], -1)
y = np.fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
x = ncp.swapaxes(x, -1, self.axes[-1])
x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= ncp.sqrt(2)
x = ncp.swapaxes(x, self.axes[-1], -1)
y = ncp.fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
else:
y = np.fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = ncp.fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.NONE:
y *= self._scale
for ax, nfft in zip(self.axes, self.nffts):
if nfft > self.dims[ax]:
y = np.take(y, range(self.dims[ax]), axis=ax)
y = ncp.take(y, range(self.dims[ax]), axis=ax)
if self.doifftpad:
y = np.pad(y, self.ifftpad)
y = ncp.pad(y, self.ifftpad)
if not self.clinear:
y = np.real(y)
y = ncp.real(y)
y = y.astype(self.rdtype)
if self.ifftshift_before.any():
y = np.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
y = ncp.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
return y

def __truediv__(self, y: npt.ArrayLike) -> npt.ArrayLike:
Expand Down

0 comments on commit 780efa3

Please sign in to comment.