Skip to content

Commit

Permalink
fix bug in pyscf.pbc
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebystrom committed Dec 5, 2024
1 parent 9fdc5ed commit 313ab12
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 48 deletions.
11 changes: 10 additions & 1 deletion ciderpress/pyscf/pbc/dft.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,15 @@ def make_cider_calc(
rhocut=None,
):
mlfunc = load_cider_model(mlfunc, mlfunc_format)
ks.xc = get_slxc_settings(xc, xkernel, ckernel, xmix)
ks._xc = get_slxc_settings(xc, xkernel, ckernel, xmix)
# Assign the PySCF-facing functional to be a simple SL
# functional to avoid hybrid DFT being called.
# NOTE this might need to be changed to some nicer
# approach later.
if mlfunc.settings.sl_settings.level == "MGGA":
ks.xc = "R2SCAN"
else:
ks.xc = "PBE"
new_ks = _CiderKS(
ks,
mlfunc,
Expand Down Expand Up @@ -115,6 +123,7 @@ def set_mlxc(
cls = CiderNumInt
self._numint = cls(
mlxc,
self._xc,
nldf_init,
sdmx_init,
xmix=xmix,
Expand Down
3 changes: 3 additions & 0 deletions ciderpress/pyscf/pbc/numint.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,7 @@ class CiderKNumInt(CiderNumIntMixin, numint.KNumInt):
def __init__(
self,
mlxc,
slxc,
nldf_init,
sdmx_init,
xmix=1.0,
Expand All @@ -604,13 +605,15 @@ def __init__(
Args:
mlxc (MappedXC): Model for XC energy
slxc (str): Semilocal part of XC functional
nldf_init (PySCFNLDFInitializer)
sdmx_init (PySCFSDMXInitializer)
xmix (float): Mixing fraction of ML functional
rhocut (float): Low density cutoff for numerical stability
dense_mesh (3-tuple or array): Denser mesh for XC integrations
"""
self.mlxc = mlxc
self.slxc = slxc
self.xmix = xmix
self.rhocut = DEFAULT_RHOCUT if rhocut is None else rhocut
self.mol = None
Expand Down
96 changes: 49 additions & 47 deletions ciderpress/pyscf/pbc/tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#

import ctypes
import unittest

import numpy as np
from numpy.fft import fftn, ifftn, irfftn, rfftn
Expand Down Expand Up @@ -52,56 +53,57 @@ def _call_mkl_test(x, fwd, mesh):
return xr


def main():
np.random.seed(34)
meshes = [
[32, 100, 19],
[32, 100, 20],
[31, 99, 19],
[2, 2, 4],
[81, 81, 81],
[80, 80, 80],
]
class TestFFT(unittest.TestCase):
def test_fft(self):
np.random.seed(34)
meshes = [
[32, 100, 19],
[32, 100, 20],
[31, 99, 19],
[2, 2, 4],
[81, 81, 81],
[80, 80, 80],
]

for mesh in meshes:
kmesh = [mesh[0], mesh[1], mesh[2] // 2 + 1]
xrin = np.random.normal(size=mesh).astype(np.float64)
xkin1 = np.random.normal(size=kmesh)
xkin2 = np.random.normal(size=kmesh)
xkin = np.empty(kmesh, dtype=np.complex128)
xkin.real = xkin1
xkin.imag = xkin2
xkin[:] = rfftn(xrin, norm="backward")
xkc = fftn(xrin.astype(np.complex128), norm="backward")
xkin2 = xkin.copy()
if mesh[2] % 2 == 0:
xkin2[0, 0, 0] = xkin2.real[0, 0, 0]
xkin2[0, 0, -1] = xkin2.real[0, 0, -1]
for ind in [0, -1]:
for i in range(xkin2.shape[-3]):
for j in range(xkin2.shape[-2]):
tmp1 = xkin2[i, j, ind]
tmp2 = xkin2[-i, -j, ind]
xkin2[i, j, ind] = 0.5 * (tmp1 + tmp2.conj())
xkin2[-i, -j, ind] = 0.5 * (tmp1.conj() + tmp2)
for mesh in meshes:
kmesh = [mesh[0], mesh[1], mesh[2] // 2 + 1]
xrin = np.random.normal(size=mesh).astype(np.float64)
xkin1 = np.random.normal(size=kmesh)
xkin2 = np.random.normal(size=kmesh)
xkin = np.empty(kmesh, dtype=np.complex128)
xkin.real = xkin1
xkin.imag = xkin2
xkin[:] = rfftn(xrin, norm="backward")
xkc = fftn(xrin.astype(np.complex128), norm="backward")
xkin2 = xkin.copy()
if mesh[2] % 2 == 0:
xkin2[0, 0, 0] = xkin2.real[0, 0, 0]
xkin2[0, 0, -1] = xkin2.real[0, 0, -1]
for ind in [0, -1]:
for i in range(xkin2.shape[-3]):
for j in range(xkin2.shape[-2]):
tmp1 = xkin2[i, j, ind]
tmp2 = xkin2[-i, -j, ind]
xkin2[i, j, ind] = 0.5 * (tmp1 + tmp2.conj())
xkin2[-i, -j, ind] = 0.5 * (tmp1.conj() + tmp2)

xk_np = rfftn(xrin)
xk_mkl = _call_mkl_test(xrin, True, mesh)
assert (xk_np.shape == np.array(kmesh)).all()
assert (xk_mkl.shape == np.array(kmesh)).all()
xk_np = rfftn(xrin)
xk_mkl = _call_mkl_test(xrin, True, mesh)
assert (xk_np.shape == np.array(kmesh)).all()
assert (xk_mkl.shape == np.array(kmesh)).all()

xr2_np = ifftn(xkc.copy(), s=mesh, norm="forward")
xr_np = irfftn(xkin.copy(), s=mesh, norm="forward")
xr3_np = irfftn(xkin2.copy(), s=mesh, norm="forward")
xr_mkl = _call_mkl_test(xkin, False, mesh)
xr2_mkl = _call_mkl_test(xkin2, False, mesh)
assert (xr_np.shape == np.array(mesh)).all()
assert (xr_mkl.shape == np.array(mesh)).all()
assert_allclose(xr2_np.imag, 0, atol=1e-9)
assert_allclose(xr2_np, xr3_np, atol=1e-9)
assert_allclose(xr2_mkl, xr3_np, atol=1e-9)
assert_allclose(xr_mkl, xr_np, atol=1e-9)
xr2_np = ifftn(xkc.copy(), s=mesh, norm="forward")
xr_np = irfftn(xkin.copy(), s=mesh, norm="forward")
xr3_np = irfftn(xkin2.copy(), s=mesh, norm="forward")
xr_mkl = _call_mkl_test(xkin, False, mesh)
xr2_mkl = _call_mkl_test(xkin2, False, mesh)
assert (xr_np.shape == np.array(mesh)).all()
assert (xr_mkl.shape == np.array(mesh)).all()
assert_allclose(xr2_np.imag, 0, atol=1e-9)
assert_allclose(xr2_np, xr3_np, atol=1e-9)
assert_allclose(xr2_mkl, xr3_np, atol=1e-9)
assert_allclose(xr_mkl, xr_np, atol=1e-9)


if __name__ == "__main__":
main()
unittest.main()

0 comments on commit 313ab12

Please sign in to comment.