Skip to content

Commit

Permalink
dft bug fix, v1.25.0
Browse files Browse the repository at this point in the history
  • Loading branch information
hanjinliu committed Jan 20, 2022
1 parent 6c015c5 commit facf8b6
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion impy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.25.0.dev0"
__version__ = "1.25.0"
__author__ = "Hanjin Liu",
__email__ = "[email protected]"

Expand Down
1 change: 1 addition & 0 deletions impy/arrays/_utils/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def make_pad(pad_width, dims, all_axes, **kwargs):
return pad_width_

def dft(img: xp.ndarray, exps: list[xp.ndarray] = None):
img = xp.asarray(img)
for ker in reversed(exps):
# K_{kx} * I_{zyx}
img = xp.tensordot(ker, img, axes=(1, -1))
Expand Down
2 changes: 1 addition & 1 deletion impy/arrays/imgarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2823,7 +2823,7 @@ def local_dft(self, key: str = "", upsample_factor: nDInt = 1, *, double_precisi
# To minimize floating error, the A term in exp(-2*pi*i*A) should be in the range of
# 0 <= A < 1.
exps: list[xp.ndarray] = \
[xp.exp(-2j * np.pi * np.mod(wave_num(sl, s, uf) * xp.arange(s)/s, 1.), dtype=dtype)
[xp.exp(-2j * np.pi * xp.mod(wave_num(sl, s, uf) * xp.arange(s)/s, 1.), dtype=dtype)
for sl, s, uf in zip(slices, self.sizesof(dims), upsample_factor)]

# Calculate chunk size for proper output shapes
Expand Down
6 changes: 3 additions & 3 deletions tests/test_fft.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import impy as ip
from pathlib import Path
import numpy as np
from impy._cupy import xp, xp_fft
from impy._cupy import xp, xp_fft, asnumpy
from numpy.testing import assert_allclose
ip.Const["SHOW_PROGRESS"] = False

Expand All @@ -10,11 +10,11 @@ def test_precision():
img = ip.imread(path)["c=1;t=0"].as_float()

assert_allclose(img.fft(shift=False),
xp_fft.fftn(xp.asarray(img.value))
asnumpy(xp_fft.fftn(xp.asarray(img.value)))
)

assert_allclose(img.fft(shift=False, double_precision=True),
xp_fft.fftn(xp.asarray(img.value.astype(np.float64))).astype(np.complex64)
asnumpy(xp_fft.fftn(xp.asarray(img.value.astype(np.float64))).astype(np.complex64))
)

assert_allclose(img.fft().ifft(), img, rtol=1e-6)
Expand Down

0 comments on commit facf8b6

Please sign in to comment.