Use scipy.fft for the PBC FFT wrappers and make NUMPY+BLAS the default engine.

Summary of the changes carried by this patch:

* .github/workflows/run_tests.sh: the generated .pyscf_conf.py selects the
  "NUMPY+BLAS" engine instead of "NUMPY".
* pyproject.toml: the scipy requirement becomes 'scipy>=1.6.0'.
* pyscf/pbc/tools/pbc.py:
  - FFT_ENGINE defaults to 'NUMPY+BLAS' (previously 'BLAS').
  - _fftn_blas/_ifftn_blas assert a 4-dimensional input and walk the leading
    axis in blocks (lib.prange) using one preallocated work buffer, instead
    of transforming one array per loop iteration.
  - Module-level _fftn_wrapper/_ifftn_wrapper built on scipy.fft with
    workers=lib.num_threads() are defined before the engine dispatch, so any
    engine value that matches no branch (and a failed pyfftw import) ends up
    on the scipy.fft implementation.
  - The 'NUMPY' branch and the catch-all 'else' branch are removed; 'BLAS'
    is now an explicit 'elif'.
  - scipy.fft is imported explicitly. A bare "import scipy" does not
    guarantee that the scipy.fft submodule is loaded on every SciPy release
    permitted by the 'scipy>=1.6.0' requirement.

NOTE(review): blank context lines and indentation were reconstructed from the
hunk line counts; verify against the target tree before applying.

diff --git a/.github/workflows/run_tests.sh b/.github/workflows/run_tests.sh
index 53df4e2d36..bfa9613840 100755
--- a/.github/workflows/run_tests.sh
+++ b/.github/workflows/run_tests.sh
@@ -4,7 +4,7 @@
 export PYTHONPATH=$(pwd):$PYTHONPATH
 ulimit -s 20000
 mkdir -p pyscftmpdir
-echo 'pbc_tools_pbc_fft_engine = "NUMPY"' > .pyscf_conf.py
+echo 'pbc_tools_pbc_fft_engine = "NUMPY+BLAS"' > .pyscf_conf.py
 echo "dftd3_DFTD3PATH = './pyscf/lib/deps/lib'" >> .pyscf_conf.py
 echo "scf_hf_SCF_mute_chkfile = True" >> .pyscf_conf.py
 echo 'TMPDIR = "./pyscftmpdir"' >> .pyscf_conf.py
diff --git a/pyproject.toml b/pyproject.toml
index 5771f2c9f7..2dbc77ead1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,7 @@
 license = { text = "Apache-2.0" }
 dependencies = [
   'numpy>=1.13,!=1.16,!=1.17',
-  'scipy!=1.5.0,!=1.5.1',
+  'scipy>=1.6.0',
   'h5py>=2.7',
   'setuptools',
 ]
diff --git a/pyscf/pbc/tools/pbc.py b/pyscf/pbc/tools/pbc.py
index 504daa85b3..10e2acefdc 100644
--- a/pyscf/pbc/tools/pbc.py
+++ b/pyscf/pbc/tools/pbc.py
@@ -16,6 +16,7 @@
 import warnings
 import ctypes
 import numpy as np
+import scipy.fft  # the fft submodule must be imported explicitly on older SciPy
 import scipy.linalg
 from pyscf import lib
 from pyscf.lib import logger
@@ -23,39 +24,55 @@
 from pyscf.pbc.lib.kpts_helper import get_kconserv, get_kconserv3  # noqa
 from pyscf import __config__
 
-FFT_ENGINE = getattr(__config__, 'pbc_tools_pbc_fft_engine', 'BLAS')
+FFT_ENGINE = getattr(__config__, 'pbc_tools_pbc_fft_engine', 'NUMPY+BLAS')
 
 def _fftn_blas(f, mesh):
-    Gx = np.fft.fftfreq(mesh[0])
-    Gy = np.fft.fftfreq(mesh[1])
-    Gz = np.fft.fftfreq(mesh[2])
-    expRGx = np.exp(np.einsum('x,k->xk', -2j*np.pi*np.arange(mesh[0]), Gx))
-    expRGy = np.exp(np.einsum('x,k->xk', -2j*np.pi*np.arange(mesh[1]), Gy))
-    expRGz = np.exp(np.einsum('x,k->xk', -2j*np.pi*np.arange(mesh[2]), Gz))
-    out = np.empty(f.shape, dtype=np.complex128)
-    buf = np.empty(mesh, dtype=np.complex128)
-    for i, fi in enumerate(f):
-        buf[:] = fi.reshape(mesh)
-        g = lib.dot(buf.reshape(mesh[0],-1).T, expRGx, c=out[i].reshape(-1,mesh[0]))
-        g = lib.dot(g.reshape(mesh[1],-1).T, expRGy, c=buf.reshape(-1,mesh[1]))
-        g = lib.dot(g.reshape(mesh[2],-1).T, expRGz, c=out[i].reshape(-1,mesh[2]))
-    return out.reshape(-1, *mesh)
+    assert f.ndim == 4
+    mx, my, mz = mesh
+    expRGx = np.exp(-2j*np.pi*np.arange(mx)[:,None] * np.fft.fftfreq(mx))
+    expRGy = np.exp(-2j*np.pi*np.arange(my)[:,None] * np.fft.fftfreq(my))
+    expRGz = np.exp(-2j*np.pi*np.arange(mz)[:,None] * np.fft.fftfreq(mz))
+    blksize = max(int(1e5 / (mx * my * mz)), 8) * 4
+    n = f.shape[0]
+    out = np.empty((n, mx*my*mz), dtype=np.complex128)
+    buf = np.empty((blksize, mx*my*mz), dtype=np.complex128)
+    for i0, i1 in lib.prange(0, n, blksize):
+        ni = i1 - i0
+        buf1 = buf[:ni]
+        out1 = out[i0:i1]
+        g = lib.transpose(f[i0:i1].reshape(ni,-1), out=buf1.reshape(-1,ni))
+        g = lib.dot(g.reshape(mx,-1).T, expRGx, c=out1.reshape(-1,mx))
+        g = lib.dot(g.reshape(my,-1).T, expRGy, c=buf1.reshape(-1,my))
+        g = lib.dot(g.reshape(mz,-1).T, expRGz, c=out1.reshape(-1,mz))
+    return out.reshape(n, *mesh)
 
 def _ifftn_blas(g, mesh):
-    Gx = np.fft.fftfreq(mesh[0])
-    Gy = np.fft.fftfreq(mesh[1])
-    Gz = np.fft.fftfreq(mesh[2])
-    expRGx = np.exp(np.einsum('x,k->xk', 2j*np.pi*np.arange(mesh[0]), Gx))
-    expRGy = np.exp(np.einsum('x,k->xk', 2j*np.pi*np.arange(mesh[1]), Gy))
-    expRGz = np.exp(np.einsum('x,k->xk', 2j*np.pi*np.arange(mesh[2]), Gz))
-    out = np.empty(g.shape, dtype=np.complex128)
-    buf = np.empty(mesh, dtype=np.complex128)
-    for i, gi in enumerate(g):
-        buf[:] = gi.reshape(mesh)
-        f = lib.dot(buf.reshape(mesh[0],-1).T, expRGx, 1./mesh[0], c=out[i].reshape(-1,mesh[0]))
-        f = lib.dot(f.reshape(mesh[1],-1).T, expRGy, 1./mesh[1], c=buf.reshape(-1,mesh[1]))
-        f = lib.dot(f.reshape(mesh[2],-1).T, expRGz, 1./mesh[2], c=out[i].reshape(-1,mesh[2]))
-    return out.reshape(-1, *mesh)
+    assert g.ndim == 4
+    mx, my, mz = mesh
+    expRGx = np.exp(2j*np.pi*np.fft.fftfreq(mx)[:,None] * np.arange(mx))
+    expRGy = np.exp(2j*np.pi*np.fft.fftfreq(my)[:,None] * np.arange(my))
+    expRGz = np.exp(2j*np.pi*np.fft.fftfreq(mz)[:,None] * np.arange(mz))
+    blksize = max(int(1e5 / (mx * my * mz)), 8) * 4
+    n = g.shape[0]
+    out = np.empty((n, mx*my*mz), dtype=np.complex128)
+    buf = np.empty((blksize, mx*my*mz), dtype=np.complex128)
+    for i0, i1 in lib.prange(0, n, blksize):
+        ni = i1 - i0
+        buf1 = buf[:ni]
+        out1 = out[i0:i1]
+        f = lib.transpose(g[i0:i1].reshape(ni,-1), out=buf1.reshape(-1,ni))
+        f = lib.dot(f.reshape(mx,-1).T, expRGx, 1./mx, c=out1.reshape(-1,mx))
+        f = lib.dot(f.reshape(my,-1).T, expRGy, 1./my, c=buf1.reshape(-1,my))
+        f = lib.dot(f.reshape(mz,-1).T, expRGz, 1./mz, c=out1.reshape(-1,mz))
+    return out.reshape(n, *mesh)
+
+nproc = lib.num_threads()
+
+def _fftn_wrapper(a):  # noqa
+    return scipy.fft.fftn(a, axes=(1,2,3), workers=nproc)
+
+def _ifftn_wrapper(a):  # noqa
+    return scipy.fft.ifftn(a, axes=(1,2,3), workers=nproc)
 
 if FFT_ENGINE == 'FFTW':
     try:
@@ -88,60 +105,50 @@ def _complex_fftn_fftw(f, mesh, func):
                ctypes.c_int(rank))
         return out
 
-    def _fftn_wrapper(a):
+    def _fftn_wrapper(a):  # noqa
         mesh = a.shape[1:]
         return _complex_fftn_fftw(a, mesh, 'fft')
-    def _ifftn_wrapper(a):
+    def _ifftn_wrapper(a):  # noqa
         mesh = a.shape[1:]
         return _complex_fftn_fftw(a, mesh, 'ifft')
 
 elif FFT_ENGINE == 'PYFFTW':
-    # pyfftw is slower than np.fft in most cases
+    # Note: pyfftw is likely slower than scipy.fft in multi-threading environments
     try:
         import pyfftw
+        pyfftw.config.PLANNER_EFFORT = 'FFTW_MEASURE'
         pyfftw.interfaces.cache.enable()
-        nproc = lib.num_threads()
-        def _fftn_wrapper(a):
+        def _fftn_wrapper(a):  # noqa
             return pyfftw.interfaces.numpy_fft.fftn(a, axes=(1,2,3), threads=nproc)
-        def _ifftn_wrapper(a):
+        def _ifftn_wrapper(a):  # noqa
             return pyfftw.interfaces.numpy_fft.ifftn(a, axes=(1,2,3), threads=nproc)
     except ImportError:
-        def _fftn_wrapper(a):
-            return np.fft.fftn(a, axes=(1,2,3))
-        def _ifftn_wrapper(a):
-            return np.fft.ifftn(a, axes=(1,2,3))
-
-elif FFT_ENGINE == 'NUMPY':
-    def _fftn_wrapper(a):
-        return np.fft.fftn(a, axes=(1,2,3))
-    def _ifftn_wrapper(a):
-        return np.fft.ifftn(a, axes=(1,2,3))
+        print('PyFFTW not installed. SciPy fft module will be used.')
 
 elif FFT_ENGINE == 'NUMPY+BLAS':
     _EXCLUDE = [17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79,
                 83, 89, 97,101,103,107,109,113,127,131,137,139,149,151,157,163,
                 167,173,179,181,191,193,197,199,211,223,227,229,233,239,241,251,
                 257,263,269,271,277,281,283,293]
-    _EXCLUDE = set(_EXCLUDE + [n*2 for n in _EXCLUDE] + [n*3 for n in _EXCLUDE])
-    def _fftn_wrapper(a):
+    _EXCLUDE = set(_EXCLUDE + [n*2 for n in _EXCLUDE[:30]] + [n*3 for n in _EXCLUDE[:20]])
+    def _fftn_wrapper(a):  # noqa
         mesh = a.shape[1:]
         if mesh[0] in _EXCLUDE and mesh[1] in _EXCLUDE and mesh[2] in _EXCLUDE:
             return _fftn_blas(a, mesh)
         else:
-            return np.fft.fftn(a, axes=(1,2,3))
-    def _ifftn_wrapper(a):
+            return scipy.fft.fftn(a, axes=(1,2,3), workers=nproc)
+    def _ifftn_wrapper(a):  # noqa
         mesh = a.shape[1:]
         if mesh[0] in _EXCLUDE and mesh[1] in _EXCLUDE and mesh[2] in _EXCLUDE:
             return _ifftn_blas(a, mesh)
         else:
-            return np.fft.ifftn(a, axes=(1,2,3))
+            return scipy.fft.ifftn(a, axes=(1,2,3), workers=nproc)
 
-#?elif:  # 'FFTW+BLAS'
-else:  # 'BLAS'
-    def _fftn_wrapper(a):
+elif FFT_ENGINE == 'BLAS':
+    def _fftn_wrapper(a):  # noqa
         mesh = a.shape[1:]
         return _fftn_blas(a, mesh)
-    def _ifftn_wrapper(a):
+    def _ifftn_wrapper(a):  # noqa
         mesh = a.shape[1:]
         return _ifftn_blas(a, mesh)
 