Skip to content

Commit

Permalink
Optimize fft performance (pyscf#2276)
Browse files Browse the repository at this point in the history
* Use scipy.fft module by default

* Optimize cache utilization of fft

* Fix lint error

* Fix flake8 F841 warning

* disable flake8
  • Loading branch information
sunqm committed Aug 29, 2024
1 parent 2ba39a3 commit 44f66cb
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export PYTHONPATH=$(pwd):$PYTHONPATH
ulimit -s 20000

mkdir -p pyscftmpdir
echo 'pbc_tools_pbc_fft_engine = "NUMPY"' > .pyscf_conf.py
echo 'pbc_tools_pbc_fft_engine = "NUMPY+BLAS"' > .pyscf_conf.py
echo "dftd3_DFTD3PATH = './pyscf/lib/deps/lib'" >> .pyscf_conf.py
echo "scf_hf_SCF_mute_chkfile = True" >> .pyscf_conf.py
echo 'TMPDIR = "./pyscftmpdir"' >> .pyscf_conf.py
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ license = { text = "Apache-2.0" }

dependencies = [
'numpy>=1.13,!=1.16,!=1.17',
'scipy!=1.5.0,!=1.5.1',
'scipy>=1.6.0',
'h5py>=2.7',
'setuptools',
]
Expand Down
115 changes: 61 additions & 54 deletions pyscf/pbc/tools/pbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,63 @@
import warnings
import ctypes
import numpy as np
import scipy
import scipy.linalg
from pyscf import lib
from pyscf.lib import logger
from pyscf.gto import ATM_SLOTS, BAS_SLOTS, ATOM_OF, PTR_COORD
from pyscf.pbc.lib.kpts_helper import get_kconserv, get_kconserv3 # noqa
from pyscf import __config__

# Name of the FFT backend, read from the `pbc_tools_pbc_fft_engine` attribute
# of the PySCF config module. The values tested further down are 'FFTW',
# 'PYFFTW', 'NUMPY', 'NUMPY+BLAS' and 'BLAS'. The later assignment overrides
# the earlier one, so the effective fallback is 'NUMPY+BLAS'.
FFT_ENGINE = getattr(__config__, 'pbc_tools_pbc_fft_engine', 'BLAS')
FFT_ENGINE = getattr(__config__, 'pbc_tools_pbc_fft_engine', 'NUMPY+BLAS')

def _fftn_blas(f, mesh):
    """Forward 3D DFT of a batch of arrays, evaluated as dense matrix products.

    Each of the three mesh axes is transformed by multiplying with an explicit
    DFT matrix through ``lib.dot`` instead of calling an FFT routine.

    Args:
        f: 4D array of shape ``(n, mesh[0], mesh[1], mesh[2])``.
        mesh: the three grid dimensions ``(mx, my, mz)``.

    Returns:
        complex128 array of shape ``(n, mx, my, mz)``.

    Raises:
        AssertionError: if ``f`` is not 4-dimensional.
    """
    assert f.ndim == 4
    mx, my, mz = mesh
    ngrids = mx * my * mz
    # Dense DFT matrices, one per axis: element [r, k] = exp(-2j*pi * r * freq_k).
    expRGx = np.exp(-2j*np.pi*np.arange(mx)[:,None] * np.fft.fftfreq(mx))
    expRGy = np.exp(-2j*np.pi*np.arange(my)[:,None] * np.fft.fftfreq(my))
    expRGz = np.exp(-2j*np.pi*np.arange(mz)[:,None] * np.fft.fftfreq(mz))
    # Arrays are transformed in batches of at most `blksize` (never below 32).
    blksize = max(int(1e5 / ngrids), 8) * 4
    n = f.shape[0]
    out = np.empty((n, ngrids), dtype=np.complex128)
    # The scratch buffer never needs more rows than there are input arrays;
    # capping it avoids allocating >= 32 full grids when n is small.
    buf = np.empty((min(blksize, n), ngrids), dtype=np.complex128)
    for i0, i1 in lib.prange(0, n, blksize):
        ni = i1 - i0
        buf1 = buf[:ni]
        out1 = out[i0:i1]
        # NOTE(review): the layout comments below assume lib.transpose / lib.dot
        # write into `out` / `c` and return that array -- confirm against lib.
        # Batch index becomes the fastest axis: (x, y, z, i).
        g = lib.transpose(f[i0:i1].reshape(ni,-1), out=buf1.reshape(-1,ni))
        # Each product consumes the leading axis and appends its transform as
        # the last axis, alternating between the two buffers.
        g = lib.dot(g.reshape(mx,-1).T, expRGx, c=out1.reshape(-1,mx))  # (y, z, i, kx)
        g = lib.dot(g.reshape(my,-1).T, expRGy, c=buf1.reshape(-1,my))  # (z, i, kx, ky)
        g = lib.dot(g.reshape(mz,-1).T, expRGz, c=out1.reshape(-1,mz))  # (i, kx, ky, kz)
    return out.reshape(n, *mesh)

def _ifftn_blas(g, mesh):
    """Inverse 3D DFT of a batch of arrays, evaluated as dense matrix products.

    Mirror image of the forward BLAS transform: the exponent sign is flipped
    and each axis product is scaled by the reciprocal of that axis length.

    Args:
        g: 4D array of shape ``(n, mesh[0], mesh[1], mesh[2])``.
        mesh: the three grid dimensions ``(mx, my, mz)``.

    Returns:
        complex128 array of shape ``(n, mx, my, mz)``.

    Raises:
        AssertionError: if ``g`` is not 4-dimensional.
    """
    assert g.ndim == 4
    mx, my, mz = mesh
    ngrids = mx * my * mz
    # Dense inverse-DFT matrices: element [k, r] = exp(+2j*pi * freq_k * r).
    expRGx = np.exp(2j*np.pi*np.fft.fftfreq(mx)[:,None] * np.arange(mx))
    expRGy = np.exp(2j*np.pi*np.fft.fftfreq(my)[:,None] * np.arange(my))
    expRGz = np.exp(2j*np.pi*np.fft.fftfreq(mz)[:,None] * np.arange(mz))
    # Arrays are transformed in batches of at most `blksize` (never below 32).
    blksize = max(int(1e5 / ngrids), 8) * 4
    n = g.shape[0]
    out = np.empty((n, ngrids), dtype=np.complex128)
    # The scratch buffer never needs more rows than there are input arrays;
    # capping it avoids allocating >= 32 full grids when n is small.
    buf = np.empty((min(blksize, n), ngrids), dtype=np.complex128)
    for i0, i1 in lib.prange(0, n, blksize):
        ni = i1 - i0
        buf1 = buf[:ni]
        out1 = out[i0:i1]
        # NOTE(review): assumes lib.transpose / lib.dot write into `out` / `c`
        # and return that array, and that the third positional argument of
        # lib.dot is a scalar prefactor -- confirm against lib.
        # Batch index becomes the fastest axis: (kx, ky, kz, i).
        f = lib.transpose(g[i0:i1].reshape(ni,-1), out=buf1.reshape(-1,ni))
        # Each product consumes the leading axis and appends its transform as
        # the last axis, alternating between the two buffers.
        f = lib.dot(f.reshape(mx,-1).T, expRGx, 1./mx, c=out1.reshape(-1,mx))  # (ky, kz, i, x)
        f = lib.dot(f.reshape(my,-1).T, expRGy, 1./my, c=buf1.reshape(-1,my))  # (kz, i, x, y)
        f = lib.dot(f.reshape(mz,-1).T, expRGz, 1./mz, c=out1.reshape(-1,mz))  # (i, x, y, z)
    return out.reshape(n, *mesh)

# Thread count taken from lib.num_threads(); passed to the FFT backends below
# as scipy.fft's `workers=` and pyfftw's `threads=` argument.
nproc = lib.num_threads()

def _fftn_wrapper(a):  # noqa
    """Forward FFT of ``a`` over axes 1, 2 and 3 via scipy.fft with ``nproc`` workers."""
    spatial_axes = (1, 2, 3)
    return scipy.fft.fftn(a, axes=spatial_axes, workers=nproc)

def _ifftn_wrapper(a):  # noqa
    """Inverse FFT of ``a`` over axes 1, 2 and 3 via scipy.fft with ``nproc`` workers."""
    spatial_axes = (1, 2, 3)
    return scipy.fft.ifftn(a, axes=spatial_axes, workers=nproc)

if FFT_ENGINE == 'FFTW':
try:
Expand Down Expand Up @@ -88,60 +105,50 @@ def _complex_fftn_fftw(f, mesh, func):
ctypes.c_int(rank))
return out

def _fftn_wrapper(a):
def _fftn_wrapper(a): # noqa
mesh = a.shape[1:]
return _complex_fftn_fftw(a, mesh, 'fft')
def _ifftn_wrapper(a):
def _ifftn_wrapper(a): # noqa
mesh = a.shape[1:]
return _complex_fftn_fftw(a, mesh, 'ifft')

elif FFT_ENGINE == 'PYFFTW':
# pyfftw is slower than np.fft in most cases
# Note: pyfftw is likely slower than scipy.fft in multi-threaded environments
try:
import pyfftw
pyfftw.config.PLANNER_EFFORT = 'FFTW_MEASURE'
pyfftw.interfaces.cache.enable()
nproc = lib.num_threads()
def _fftn_wrapper(a):
def _fftn_wrapper(a): # noqa
return pyfftw.interfaces.numpy_fft.fftn(a, axes=(1,2,3), threads=nproc)
def _ifftn_wrapper(a):
def _ifftn_wrapper(a): # noqa
return pyfftw.interfaces.numpy_fft.ifftn(a, axes=(1,2,3), threads=nproc)
except ImportError:
def _fftn_wrapper(a):
return np.fft.fftn(a, axes=(1,2,3))
def _ifftn_wrapper(a):
return np.fft.ifftn(a, axes=(1,2,3))

elif FFT_ENGINE == 'NUMPY':
def _fftn_wrapper(a):
return np.fft.fftn(a, axes=(1,2,3))
def _ifftn_wrapper(a):
return np.fft.ifftn(a, axes=(1,2,3))
print('PyFFTW not installed. SciPy fft module will be used.')

elif FFT_ENGINE == 'NUMPY+BLAS':
_EXCLUDE = [17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79,
83, 89, 97,101,103,107,109,113,127,131,137,139,149,151,157,163,
167,173,179,181,191,193,197,199,211,223,227,229,233,239,241,251,
257,263,269,271,277,281,283,293]
_EXCLUDE = set(_EXCLUDE + [n*2 for n in _EXCLUDE] + [n*3 for n in _EXCLUDE])
def _fftn_wrapper(a):
_EXCLUDE = set(_EXCLUDE + [n*2 for n in _EXCLUDE[:30]] + [n*3 for n in _EXCLUDE[:20]])
def _fftn_wrapper(a): # noqa
mesh = a.shape[1:]
if mesh[0] in _EXCLUDE and mesh[1] in _EXCLUDE and mesh[2] in _EXCLUDE:
return _fftn_blas(a, mesh)
else:
return np.fft.fftn(a, axes=(1,2,3))
def _ifftn_wrapper(a):
return scipy.fft.fftn(a, axes=(1,2,3), workers=nproc)
def _ifftn_wrapper(a): # noqa
mesh = a.shape[1:]
if mesh[0] in _EXCLUDE and mesh[1] in _EXCLUDE and mesh[2] in _EXCLUDE:
return _ifftn_blas(a, mesh)
else:
return np.fft.ifftn(a, axes=(1,2,3))
return scipy.fft.ifftn(a, axes=(1,2,3), workers=nproc)

#?elif: # 'FFTW+BLAS'
else: # 'BLAS'
def _fftn_wrapper(a):
elif FFT_ENGINE == 'BLAS':
def _fftn_wrapper(a): # noqa
mesh = a.shape[1:]
return _fftn_blas(a, mesh)
def _ifftn_wrapper(a):
def _ifftn_wrapper(a): # noqa
mesh = a.shape[1:]
return _ifftn_blas(a, mesh)

Expand Down

0 comments on commit 44f66cb

Please sign in to comment.