From facf8b60a43fe4ab252789471bc51128eb07c6da Mon Sep 17 00:00:00 2001 From: hanjinliu Date: Thu, 20 Jan 2022 20:51:45 +0900 Subject: [PATCH] dft bug fix, v1.25.0 --- impy/__init__.py | 2 +- impy/arrays/_utils/_misc.py | 1 + impy/arrays/imgarray.py | 2 +- tests/test_fft.py | 6 +++--- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/impy/__init__.py b/impy/__init__.py index 0d089b34..b78d69ca 100644 --- a/impy/__init__.py +++ b/impy/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.25.0.dev0" +__version__ = "1.25.0" __author__ = "Hanjin Liu", __email__ = "liuhanjin-sc@g.ecc.u-tokyo.ac.jp" diff --git a/impy/arrays/_utils/_misc.py b/impy/arrays/_utils/_misc.py index 5bed0191..57d64120 100644 --- a/impy/arrays/_utils/_misc.py +++ b/impy/arrays/_utils/_misc.py @@ -59,6 +59,7 @@ def make_pad(pad_width, dims, all_axes, **kwargs): return pad_width_ def dft(img: xp.ndarray, exps: list[xp.ndarray] = None): + img = xp.asarray(img) for ker in reversed(exps): # K_{kx} * I_{zyx} img = xp.tensordot(ker, img, axes=(1, -1)) diff --git a/impy/arrays/imgarray.py b/impy/arrays/imgarray.py index 9fd3692d..35a4c970 100644 --- a/impy/arrays/imgarray.py +++ b/impy/arrays/imgarray.py @@ -2823,7 +2823,7 @@ def local_dft(self, key: str = "", upsample_factor: nDInt = 1, *, double_precisi # To minimize floating error, the A term in exp(-2*pi*i*A) should be in the range of # 0 <= A < 1. exps: list[xp.ndarray] = \ - [xp.exp(-2j * np.pi * np.mod(wave_num(sl, s, uf) * xp.arange(s)/s, 1.), dtype=dtype) + [xp.exp(-2j * np.pi * xp.mod(wave_num(sl, s, uf) * xp.arange(s)/s, 1.), dtype=dtype) for sl, s, uf in zip(slices, self.sizesof(dims), upsample_factor)] # Calculate chunk size for proper output shapes diff --git a/tests/test_fft.py b/tests/test_fft.py index 9086d52c..ded0c215 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -1,7 +1,7 @@ import impy as ip from pathlib import Path import numpy as np -from impy._cupy import xp, xp_fft +from impy._cupy import xp, xp_fft, asnumpy from numpy.testing import assert_allclose ip.Const["SHOW_PROGRESS"] = False @@ -10,11 +10,11 @@ def test_precision(): img = ip.imread(path)["c=1;t=0"].as_float() assert_allclose(img.fft(shift=False), - xp_fft.fftn(xp.asarray(img.value)) + asnumpy(xp_fft.fftn(xp.asarray(img.value))) ) assert_allclose(img.fft(shift=False, double_precision=True), - xp_fft.fftn(xp.asarray(img.value.astype(np.float64))).astype(np.complex64) + asnumpy(xp_fft.fftn(xp.asarray(img.value.astype(np.float64))).astype(np.complex64)) ) assert_allclose(img.fft().ifft(), img, rtol=1e-6)