Skip to content

Commit

Permalink
Refactor VHFOpt, to manage the memory of CVHFOpt in Python
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Sep 30, 2023
1 parent cf99035 commit 97fac95
Show file tree
Hide file tree
Showing 20 changed files with 446 additions and 421 deletions.
26 changes: 10 additions & 16 deletions pyscf/df/df_jk.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,23 +338,21 @@ def get_j(dfobj, dm, hermi=1, direct_scf_tol=1e-13):
mol = dfobj.mol
if dfobj._vjopt is None:
dfobj.auxmol = auxmol = addons.make_auxmol(mol, dfobj.auxbasis)
opt = _vhf.VHFOpt(mol, 'int3c2e', 'CVHFnr3c2e_schwarz_cond')
opt.direct_scf_tol = direct_scf_tol
opt = _vhf._VHFOpt(mol, 'int3c2e', 'CVHFnr3c2e_schwarz_cond',
dmcondname='CVHFnr_dm_cond',
direct_scf_tol=direct_scf_tol)

# q_cond part 1: the regular int2e (ij|ij) for mol's basis
opt.init_cvhf_direct(mol, 'int2e', 'CVHFsetnr_direct_scf')
mol_q_cond = lib.frompointer(opt._this.contents.q_cond, mol.nbas**2)
opt.init_cvhf_direct(mol, 'int2e', 'CVHFnr_int2e_q_cond')

# Update q_cond to include the 2e-integrals (auxmol|auxmol)
j2c = auxmol.intor('int2c2e', hermi=1)
j2c_diag = numpy.sqrt(abs(j2c.diagonal()))
aux_loc = auxmol.ao_loc
aux_q_cond = [j2c_diag[i0:i1].max()
for i0, i1 in zip(aux_loc[:-1], aux_loc[1:])]
q_cond = numpy.hstack((mol_q_cond, aux_q_cond))
fsetqcond = _vhf.libcvhf.CVHFset_q_cond
fsetqcond(opt._this, q_cond.ctypes.data_as(ctypes.c_void_p),
ctypes.c_int(q_cond.size))
q_cond = numpy.hstack((opt.q_cond.ravel(), aux_q_cond))
opt.q_cond = q_cond

try:
opt.j2c = j2c = scipy.linalg.cho_factor(j2c, lower=True)
Expand Down Expand Up @@ -388,8 +386,7 @@ def get_j(dfobj, dm, hermi=1, direct_scf_tol=1e-13):
nbas = mol.nbas
nbas1 = mol.nbas + dfobj.auxmol.nbas
shls_slice = (0, nbas, 0, nbas, nbas, nbas1, nbas1, nbas1+1)
with lib.temporary_env(opt, prescreen='CVHFnr3c2e_vj_pass1_prescreen',
_dmcondname='CVHFsetnr_direct_scf_dm'):
with lib.temporary_env(opt, prescreen='CVHFnr3c2e_vj_pass1_prescreen'):
jaux = jk.get_jk(fakemol, dm, ['ijkl,ji->kl']*n_dm, 'int3c2e',
aosym='s2ij', hermi=0, shls_slice=shls_slice,
vhfopt=opt)
Expand All @@ -408,17 +405,14 @@ def get_j(dfobj, dm, hermi=1, direct_scf_tol=1e-13):
# Next compute the Coulomb matrix
# j3c = fauxe2(mol, auxmol)
# vj = numpy.einsum('ijk,k->ij', j3c, rho)
# temporarily set "_dmcondname=None" to skip the call to set_dm method.
with lib.temporary_env(opt, prescreen='CVHFnr3c2e_vj_pass2_prescreen',
_dmcondname=None):
# CVHFnr3c2e_vj_pass2_prescreen requires custom dm_cond
aux_loc = dfobj.auxmol.ao_loc
dm_cond = [abs(rho[:,:,i0:i1]).max()
for i0, i1 in zip(aux_loc[:-1], aux_loc[1:])]
dm_cond = numpy.array(dm_cond)
fsetcond = _vhf.libcvhf.CVHFset_dm_cond
fsetcond(opt._this, dm_cond.ctypes.data_as(ctypes.c_void_p),
ctypes.c_int(dm_cond.size))

opt.dm_cond = numpy.array(dm_cond)
vj = jk.get_jk(fakemol, rho, ['ijkl,lk->ij']*n_dm, 'int3c2e',
aosym='s2ij', hermi=1, shls_slice=shls_slice,
vhfopt=opt)
Expand Down Expand Up @@ -546,7 +540,7 @@ def fjk(dm):
energy = method.scf()
print(energy, -76.0807386770) # normal DHF energy is -76.0815679438127

method = density_fit(pyscf.scf.UKS(mol), 'weigend')
method = density_fit(pyscf.scf.UKS(mol), 'weigend', only_dfj = True)
energy = method.scf()
print(energy, -75.8547753298)

Expand Down
40 changes: 19 additions & 21 deletions pyscf/grad/rhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,28 +150,26 @@ def get_jk(mol, dm):
'''J = ((-nabla i) j| kl) D_lk
K = ((-nabla i) j| kl) D_jk
'''
vhfopt = _vhf.VHFOpt(mol, 'int2e_ip1ip2', 'CVHFgrad_jk_prescreen',
'CVHFgrad_jk_direct_scf')
dm = numpy.asarray(dm, order='C')
if dm.ndim == 3:
n_dm = dm.shape[0]
else:
n_dm = 1
libcvhf = _vhf.libcvhf
vhfopt = _vhf._VHFOpt(mol, 'int2e_ip1', 'CVHFgrad_jk_prescreen',
dmcondname='CVHFnr_dm_cond1')
ao_loc = mol.ao_loc_nr()
fsetdm = getattr(_vhf.libcvhf, 'CVHFgrad_jk_direct_scf_dm')
fsetdm(vhfopt._this,
dm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(n_dm),
ao_loc.ctypes.data_as(ctypes.c_void_p),
mol._atm.ctypes.data_as(ctypes.c_void_p), mol.natm,
mol._bas.ctypes.data_as(ctypes.c_void_p), mol.nbas,
mol._env.ctypes.data_as(ctypes.c_void_p))

# Update the vhfopt's attributes intor. Function direct_mapdm needs
# vhfopt._intor and vhfopt._cintopt to compute J/K. intor was initialized
# as int2e_ip1ip2. It should be int2e_ip1
vhfopt._intor = intor = mol._add_suffix('int2e_ip1')
vhfopt._cintopt = None

nbas = mol.nbas
q_cond = numpy.empty((2, nbas, nbas))
with mol.with_integral_screen(vhfopt.direct_scf_tol**2):
libcvhf.CVHFnr_int2e_pp_q_cond(
getattr(libcvhf, mol._add_suffix('int2e_ip1ip2')),
lib.c_null_ptr(), q_cond[0].ctypes,
ao_loc.ctypes, mol._atm.ctypes, ctypes.c_int(mol.natm),
mol._bas.ctypes, ctypes.c_int(nbas), mol._env.ctypes)
libcvhf.CVHFnr_int2e_q_cond(
getattr(libcvhf, mol._add_suffix('int2e')),
lib.c_null_ptr(), q_cond[1].ctypes,
ao_loc.ctypes, mol._atm.ctypes, ctypes.c_int(mol.natm),
mol._bas.ctypes, ctypes.c_int(nbas), mol._env.ctypes)
vhfopt.q_cond = q_cond

intor = mol._add_suffix('int2e_ip1')
vj, vk = _vhf.direct_mapdm(intor, # (nabla i,j|k,l)
's2kl', # ip1_sph has k>=l,
('lk->s1ij', 'jk->s1il'),
Expand Down
45 changes: 24 additions & 21 deletions pyscf/hessian/rhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,30 +182,33 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
return e1, ej, ek

def _make_vhfopt(mol, dms, key, vhf_intor):
if not hasattr(_vhf.libcvhf, vhf_intor):
libcvhf = _vhf.libcvhf
if not hasattr(libcvhf, vhf_intor):
return None

vhfopt = _vhf.VHFOpt(mol, vhf_intor, 'CVHF'+key+'_prescreen',
'CVHF'+key+'_direct_scf')
dms = numpy.asarray(dms, order='C')
if dms.ndim == 3:
n_dm = dms.shape[0]
else:
n_dm = 1
vhfopt = _vhf._VHFOpt(mol, 'int2e_'+key, 'CVHF'+key+'_prescreen',
dmcondname='CVHFnr_dm_cond1')
ao_loc = mol.ao_loc_nr()
fsetdm = getattr(_vhf.libcvhf, 'CVHF'+key+'_direct_scf_dm')
fsetdm(vhfopt._this,
dms.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(n_dm),
ao_loc.ctypes.data_as(ctypes.c_void_p),
mol._atm.ctypes.data_as(ctypes.c_void_p), mol.natm,
mol._bas.ctypes.data_as(ctypes.c_void_p), mol.nbas,
mol._env.ctypes.data_as(ctypes.c_void_p))

# Update the vhfopt's attributes intor. Function direct_mapdm needs
# vhfopt._intor and vhfopt._cintopt to compute J/K.
if vhf_intor != 'int2e_'+key:
vhfopt._intor = mol._add_suffix('int2e_'+key)
vhfopt._cintopt = None
nbas = mol.nbas
q_cond = numpy.empty((2, nbas, nbas))
with mol.with_integral_screen(vhfopt.direct_scf_tol**2):
if vhf_intor == 'int2e_ip1ip2':
fqcond = libcvhf.CVHFnr_int2e_pp_q_cond
elif vhf_intor in ('int2e_ipip1ipip2', 'int2e_ipvip1ipvip2'):
fqcond = libcvhf.CVHFnr_int2e_pppp_q_cond
else:
raise NotImplementedError(vhf_intor)
fqcond(
getattr(libcvhf, mol._add_suffix(vhf_intor)),
lib.c_null_ptr(), q_cond[0].ctypes,
ao_loc.ctypes, mol._atm.ctypes, ctypes.c_int(mol.natm),
mol._bas.ctypes, ctypes.c_int(nbas), mol._env.ctypes)
libcvhf.CVHFnr_int2e_q_cond(
getattr(libcvhf, mol._add_suffix('int2e')),
lib.c_null_ptr(), q_cond[1].ctypes,
ao_loc.ctypes, mol._atm.ctypes, ctypes.c_int(mol.natm),
mol._bas.ctypes, ctypes.c_int(nbas), mol._env.ctypes)
vhfopt.q_cond = q_cond
return vhfopt


Expand Down
23 changes: 18 additions & 5 deletions pyscf/lib/pbc/nr_direct.c
Original file line number Diff line number Diff line change
Expand Up @@ -1033,9 +1033,9 @@ void PBCVHF_direct_drv_nodddd(
free(qidx_iijj);
}

void PBCVHFsetnr_direct_scf(int (*intor)(), CINTOpt *cintopt, int16_t *qindex,
int *ao_loc, int *atm, int natm,
int *bas, int nbas, double *env)
void PBCVHFnr_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, int16_t *qindex,
int *ao_loc, int *atm, int natm,
int *bas, int nbas, double *env)
{
size_t Nbas = nbas;
size_t Nbas2 = Nbas * Nbas;
Expand Down Expand Up @@ -1109,8 +1109,15 @@ void PBCVHFsetnr_direct_scf(int (*intor)(), CINTOpt *cintopt, int16_t *qindex,
}
}

void PBCVHFsetnr_sindex(int16_t *sindex, int *atm, int natm,
int *bas, int nbas, double *env)
void PBCVHFsetnr_direct_scf(int (*intor)(), CINTOpt *cintopt, int16_t *qindex,
int *ao_loc, int *atm, int natm,
int *bas, int nbas, double *env)
{
PBCVHFnr_int2e_q_cond(intor, cintopt, qindex, ao_loc, atm, natm, bas, nbas, env);
}

void PBCVHFnr_sindex(int16_t *sindex, int *atm, int natm,
int *bas, int nbas, double *env)
{
size_t Nbas = nbas;
size_t Nbas1 = nbas + 1;
Expand Down Expand Up @@ -1220,3 +1227,9 @@ void PBCVHFsetnr_sindex(int16_t *sindex, int *atm, int natm,
free(exps);
free(exps_group_loc);
}

void PBCVHFsetnr_sindex(int16_t *sindex, int *atm, int natm,
int *bas, int nbas, double *env)
{
PBCVHFnr_sindex(sindex, atm, natm, bas, nbas, env);
}
Loading

0 comments on commit 97fac95

Please sign in to comment.