From 97fac959d12f3c84b99f82c63f9f476066d29c69 Mon Sep 17 00:00:00 2001 From: Qiming Sun Date: Fri, 29 Sep 2023 23:16:15 -0700 Subject: [PATCH] Refactor VHFOpt, to manage the memory of CVHFOpt in Python --- pyscf/df/df_jk.py | 26 ++---- pyscf/grad/rhf.py | 40 ++++---- pyscf/hessian/rhf.py | 45 ++++----- pyscf/lib/pbc/nr_direct.c | 23 ++++- pyscf/lib/vhf/hessian_screen.c | 134 +++++++++++++-------------- pyscf/lib/vhf/nr_sgx_direct.c | 61 ++++++++----- pyscf/lib/vhf/nr_sr_vhf.c | 14 ++- pyscf/lib/vhf/optimizer.c | 22 +++++ pyscf/lib/vhf/optimizer.h | 8 ++ pyscf/lib/vhf/rkb_screen.c | 125 +++++++++++++------------ pyscf/pbc/scf/rsjk.py | 4 +- pyscf/scf/_vhf.py | 161 ++++++++++++++++----------------- pyscf/scf/dhf.py | 78 ++++++++++------ pyscf/scf/hf.py | 7 +- pyscf/scf/test/test_vhf.py | 5 +- pyscf/sgx/sgx.py | 38 ++------ pyscf/sgx/sgx_jk.py | 46 +--------- pyscf/sgx/test/test_sgx.py | 13 ++- pyscf/sgx/test/test_sgx_jk.py | 6 +- pyscf/x2c/x2c.py | 11 +-- 20 files changed, 446 insertions(+), 421 deletions(-) diff --git a/pyscf/df/df_jk.py b/pyscf/df/df_jk.py index e5f0ab7b5a..e36de3ea74 100644 --- a/pyscf/df/df_jk.py +++ b/pyscf/df/df_jk.py @@ -338,12 +338,12 @@ def get_j(dfobj, dm, hermi=1, direct_scf_tol=1e-13): mol = dfobj.mol if dfobj._vjopt is None: dfobj.auxmol = auxmol = addons.make_auxmol(mol, dfobj.auxbasis) - opt = _vhf.VHFOpt(mol, 'int3c2e', 'CVHFnr3c2e_schwarz_cond') - opt.direct_scf_tol = direct_scf_tol + opt = _vhf._VHFOpt(mol, 'int3c2e', 'CVHFnr3c2e_schwarz_cond', + dmcondname='CVHFnr_dm_cond', + direct_scf_tol=direct_scf_tol) # q_cond part 1: the regular int2e (ij|ij) for mol's basis - opt.init_cvhf_direct(mol, 'int2e', 'CVHFsetnr_direct_scf') - mol_q_cond = lib.frompointer(opt._this.contents.q_cond, mol.nbas**2) + opt.init_cvhf_direct(mol, 'int2e', 'CVHFnr_int2e_q_cond') # Update q_cond to include the 2e-integrals (auxmol|auxmol) j2c = auxmol.intor('int2c2e', hermi=1) @@ -351,10 +351,8 @@ def get_j(dfobj, dm, hermi=1, direct_scf_tol=1e-13): aux_loc = auxmol.ao_loc aux_q_cond = [j2c_diag[i0:i1].max() for i0, i1 in zip(aux_loc[:-1], aux_loc[1:])] - q_cond = numpy.hstack((mol_q_cond, aux_q_cond)) - fsetqcond = _vhf.libcvhf.CVHFset_q_cond - fsetqcond(opt._this, q_cond.ctypes.data_as(ctypes.c_void_p), - ctypes.c_int(q_cond.size)) + q_cond = numpy.hstack((opt.q_cond.ravel(), aux_q_cond)) + opt.q_cond = q_cond try: opt.j2c = j2c = scipy.linalg.cho_factor(j2c, lower=True) @@ -388,8 +386,7 @@ def get_j(dfobj, dm, hermi=1, direct_scf_tol=1e-13): nbas = mol.nbas nbas1 = mol.nbas + dfobj.auxmol.nbas shls_slice = (0, nbas, 0, nbas, nbas, nbas1, nbas1, nbas1+1) - with lib.temporary_env(opt, prescreen='CVHFnr3c2e_vj_pass1_prescreen', - _dmcondname='CVHFsetnr_direct_scf_dm'): + with lib.temporary_env(opt, prescreen='CVHFnr3c2e_vj_pass1_prescreen'): jaux = jk.get_jk(fakemol, dm, ['ijkl,ji->kl']*n_dm, 'int3c2e', aosym='s2ij', hermi=0, shls_slice=shls_slice, vhfopt=opt) @@ -408,17 +405,14 @@ def get_j(dfobj, dm, hermi=1, direct_scf_tol=1e-13): # Next compute the Coulomb matrix # j3c = fauxe2(mol, auxmol) # vj = numpy.einsum('ijk,k->ij', j3c, rho) + # temporarily set "_dmcondname=None" to skip the call to set_dm method. with lib.temporary_env(opt, prescreen='CVHFnr3c2e_vj_pass2_prescreen', _dmcondname=None): # CVHFnr3c2e_vj_pass2_prescreen requires custom dm_cond aux_loc = dfobj.auxmol.ao_loc dm_cond = [abs(rho[:,:,i0:i1]).max() for i0, i1 in zip(aux_loc[:-1], aux_loc[1:])] - dm_cond = numpy.array(dm_cond) - fsetcond = _vhf.libcvhf.CVHFset_dm_cond - fsetcond(opt._this, dm_cond.ctypes.data_as(ctypes.c_void_p), - ctypes.c_int(dm_cond.size)) - + opt.dm_cond = numpy.array(dm_cond) vj = jk.get_jk(fakemol, rho, ['ijkl,lk->ij']*n_dm, 'int3c2e', aosym='s2ij', hermi=1, shls_slice=shls_slice, vhfopt=opt) @@ -546,7 +540,7 @@ def fjk(dm): energy = method.scf() print(energy, -76.0807386770) # normal DHF energy is -76.0815679438127 - method = density_fit(pyscf.scf.UKS(mol), 'weigend') + method = density_fit(pyscf.scf.UKS(mol), 'weigend', only_dfj = True) energy = method.scf() print(energy, -75.8547753298) diff --git a/pyscf/grad/rhf.py b/pyscf/grad/rhf.py index 6d95dc6de2..23eee431c0 100644 --- a/pyscf/grad/rhf.py +++ b/pyscf/grad/rhf.py @@ -150,28 +150,26 @@ def get_jk(mol, dm): '''J = ((-nabla i) j| kl) D_lk K = ((-nabla i) j| kl) D_jk ''' - vhfopt = _vhf.VHFOpt(mol, 'int2e_ip1ip2', 'CVHFgrad_jk_prescreen', - 'CVHFgrad_jk_direct_scf') - dm = numpy.asarray(dm, order='C') - if dm.ndim == 3: - n_dm = dm.shape[0] - else: - n_dm = 1 + libcvhf = _vhf.libcvhf + vhfopt = _vhf._VHFOpt(mol, 'int2e_ip1', 'CVHFgrad_jk_prescreen', + dmcondname='CVHFnr_dm_cond1') ao_loc = mol.ao_loc_nr() - fsetdm = getattr(_vhf.libcvhf, 'CVHFgrad_jk_direct_scf_dm') - fsetdm(vhfopt._this, - dm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(n_dm), - ao_loc.ctypes.data_as(ctypes.c_void_p), - mol._atm.ctypes.data_as(ctypes.c_void_p), mol.natm, - mol._bas.ctypes.data_as(ctypes.c_void_p), mol.nbas, - mol._env.ctypes.data_as(ctypes.c_void_p)) - - # Update the vhfopt's attributes intor. Function direct_mapdm needs - # vhfopt._intor and vhfopt._cintopt to compute J/K. intor was initialized - # as int2e_ip1ip2. It should be int2e_ip1 - vhfopt._intor = intor = mol._add_suffix('int2e_ip1') - vhfopt._cintopt = None - + nbas = mol.nbas + q_cond = numpy.empty((2, nbas, nbas)) + with mol.with_integral_screen(vhfopt.direct_scf_tol**2): + libcvhf.CVHFnr_int2e_pp_q_cond( + getattr(libcvhf, mol._add_suffix('int2e_ip1ip2')), + lib.c_null_ptr(), q_cond[0].ctypes, + ao_loc.ctypes, mol._atm.ctypes, ctypes.c_int(mol.natm), + mol._bas.ctypes, ctypes.c_int(nbas), mol._env.ctypes) + libcvhf.CVHFnr_int2e_q_cond( + getattr(libcvhf, mol._add_suffix('int2e')), + lib.c_null_ptr(), q_cond[1].ctypes, + ao_loc.ctypes, mol._atm.ctypes, ctypes.c_int(mol.natm), + mol._bas.ctypes, ctypes.c_int(nbas), mol._env.ctypes) + vhfopt.q_cond = q_cond + + intor = mol._add_suffix('int2e_ip1') vj, vk = _vhf.direct_mapdm(intor, # (nabla i,j|k,l) 's2kl', # ip1_sph has k>=l, ('lk->s1ij', 'jk->s1il'), diff --git a/pyscf/hessian/rhf.py b/pyscf/hessian/rhf.py index 6e27ba5bf3..13675f2e0a 100644 --- a/pyscf/hessian/rhf.py +++ b/pyscf/hessian/rhf.py @@ -182,30 +182,33 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None, return e1, ej, ek def _make_vhfopt(mol, dms, key, vhf_intor): - if not hasattr(_vhf.libcvhf, vhf_intor): + libcvhf = _vhf.libcvhf + if not hasattr(libcvhf, vhf_intor): return None - vhfopt = _vhf.VHFOpt(mol, vhf_intor, 'CVHF'+key+'_prescreen', - 'CVHF'+key+'_direct_scf') - dms = numpy.asarray(dms, order='C') - if dms.ndim == 3: - n_dm = dms.shape[0] - else: - n_dm = 1 + vhfopt = _vhf._VHFOpt(mol, 'int2e_'+key, 'CVHF'+key+'_prescreen', + dmcondname='CVHFnr_dm_cond1') ao_loc = mol.ao_loc_nr() - fsetdm = getattr(_vhf.libcvhf, 'CVHF'+key+'_direct_scf_dm') - fsetdm(vhfopt._this, - dms.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(n_dm), - ao_loc.ctypes.data_as(ctypes.c_void_p), - mol._atm.ctypes.data_as(ctypes.c_void_p), mol.natm, - mol._bas.ctypes.data_as(ctypes.c_void_p), mol.nbas, - mol._env.ctypes.data_as(ctypes.c_void_p)) - - # Update the vhfopt's attributes intor. Function direct_mapdm needs - # vhfopt._intor and vhfopt._cintopt to compute J/K. - if vhf_intor != 'int2e_'+key: - vhfopt._intor = mol._add_suffix('int2e_'+key) - vhfopt._cintopt = None + nbas = mol.nbas + q_cond = numpy.empty((2, nbas, nbas)) + with mol.with_integral_screen(vhfopt.direct_scf_tol**2): + if vhf_intor == 'int2e_ip1ip2': + fqcond = libcvhf.CVHFnr_int2e_pp_q_cond + elif vhf_intor in ('int2e_ipip1ipip2', 'int2e_ipvip1ipvip2'): + fqcond = libcvhf.CVHFnr_int2e_pppp_q_cond + else: + raise NotImplementedError(vhf_intor) + fqcond( + getattr(libcvhf, mol._add_suffix(vhf_intor)), + lib.c_null_ptr(), q_cond[0].ctypes, + ao_loc.ctypes, mol._atm.ctypes, ctypes.c_int(mol.natm), + mol._bas.ctypes, ctypes.c_int(nbas), mol._env.ctypes) + libcvhf.CVHFnr_int2e_q_cond( + getattr(libcvhf, mol._add_suffix('int2e')), + lib.c_null_ptr(), q_cond[1].ctypes, + ao_loc.ctypes, mol._atm.ctypes, ctypes.c_int(mol.natm), + mol._bas.ctypes, ctypes.c_int(nbas), mol._env.ctypes) + vhfopt.q_cond = q_cond return vhfopt diff --git a/pyscf/lib/pbc/nr_direct.c b/pyscf/lib/pbc/nr_direct.c index 47bf9c561d..0aa088a369 100644 --- a/pyscf/lib/pbc/nr_direct.c +++ b/pyscf/lib/pbc/nr_direct.c @@ -1033,9 +1033,9 @@ void PBCVHF_direct_drv_nodddd( free(qidx_iijj); } -void PBCVHFsetnr_direct_scf(int (*intor)(), CINTOpt *cintopt, int16_t *qindex, - int *ao_loc, int *atm, int natm, - int *bas, int nbas, double *env) +void PBCVHFnr_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, int16_t *qindex, + int *ao_loc, int *atm, int natm, + int *bas, int nbas, double *env) { size_t Nbas = nbas; size_t Nbas2 = Nbas * Nbas; @@ -1109,8 +1109,15 @@ void PBCVHFsetnr_direct_scf(int (*intor)(), CINTOpt *cintopt, int16_t *qindex, } } -void PBCVHFsetnr_sindex(int16_t *sindex, int *atm, int natm, - int *bas, int nbas, double *env) +void PBCVHFsetnr_direct_scf(int (*intor)(), CINTOpt *cintopt, int16_t *qindex, + int *ao_loc, int *atm, int natm, + int *bas, int nbas, double *env) +{ + PBCVHFnr_int2e_q_cond(intor, cintopt, qindex, ao_loc, atm, natm, bas, nbas, env); +} + +void PBCVHFnr_sindex(int16_t *sindex, int *atm, int natm, + int *bas, int nbas, double *env) { size_t Nbas = nbas; size_t Nbas1 = nbas + 1; @@ -1220,3 +1227,9 @@ void PBCVHFsetnr_sindex(int16_t *sindex, int *atm, int natm, free(exps); free(exps_group_loc); } + +void PBCVHFsetnr_sindex(int16_t *sindex, int *atm, int natm, + int *bas, int nbas, double *env) +{ + PBCVHFnr_sindex(sindex, atm, natm, bas, nbas, env); +} diff --git a/pyscf/lib/vhf/hessian_screen.c b/pyscf/lib/vhf/hessian_screen.c index bc55a1b409..51479695bc 100644 --- a/pyscf/lib/vhf/hessian_screen.c +++ b/pyscf/lib/vhf/hessian_screen.c @@ -67,32 +67,16 @@ int CVHFgrad_jk_prescreen(int *shls, CVHFOpt *opt, || ( opt->dm_cond[j*n+l] > dmin)); } -void CVHFgrad_jk_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, +void CVHFnr_int2e_pp_q_cond(int (*intor)(), CINTOpt *cintopt, double *q_cond, int *ao_loc, int *atm, int natm, int *bas, int nbas, double *env) { - if (opt->q_cond != NULL) { - free(opt->q_cond); - } - nbas = opt->nbas; - size_t Nbas = nbas; - size_t Nbas2 = Nbas * Nbas; - // First n*n elements for derivatives, the next n*n elements for regular ERIs - opt->q_cond = (double *)malloc(sizeof(double) * Nbas2*2); - - if (ao_loc[nbas] == CINTtot_cgto_spheric(bas, nbas)) { - CVHFset_int2e_q_cond(int2e_sph, NULL, opt->q_cond+Nbas2, ao_loc, - atm, natm, bas, nbas, env); - } else { - CVHFset_int2e_q_cond(int2e_cart, NULL, opt->q_cond+Nbas2, ao_loc, - atm, natm, bas, nbas, env); - } - + int nbas2 = nbas * nbas; int shls_slice[] = {0, nbas}; const int cache_size = GTOmax_cache_size(intor, shls_slice, 1, atm, natm, bas, nbas, env); #pragma omp parallel \ - shared(opt, intor, cintopt, ao_loc, atm, natm, bas, nbas, env) + shared(intor, cintopt, ao_loc, atm, natm, bas, nbas, env) { double qtmp; int i, j, iijj, di, dj, ish, jsh; @@ -108,9 +92,9 @@ void CVHFgrad_jk_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, double *bufx = buf; double *bufy, *bufz; #pragma omp for schedule(dynamic, 4) - for (ij = 0; ij < Nbas2; ij++) { - ish = ij / Nbas; - jsh = ij - ish * Nbas; + for (ij = 0; ij < nbas2; ij++) { + ish = ij / nbas; + jsh = ij - ish * nbas; di = ao_loc[ish+1] - ao_loc[ish]; dj = ao_loc[jsh+1] - ao_loc[jsh]; shls[0] = ish; @@ -131,13 +115,37 @@ void CVHFgrad_jk_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, } } qtmp = sqrt(qtmp); } - opt->q_cond[ish*nbas+jsh] = qtmp; + q_cond[ish*nbas+jsh] = qtmp; } free(buf); free(cache); } } +void CVHFgrad_jk_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, + int *ao_loc, int *atm, int natm, + int *bas, int nbas, double *env) +{ + if (opt->q_cond != NULL) { + free(opt->q_cond); + } + nbas = opt->nbas; + size_t Nbas = nbas; + size_t Nbas2 = Nbas * Nbas; + // First n*n elements for derivatives, the next n*n elements for regular ERIs + opt->q_cond = (double *)malloc(sizeof(double) * Nbas2*2); + + if (ao_loc[nbas] == CINTtot_cgto_spheric(bas, nbas)) { + CVHFnr_int2e_q_cond(int2e_sph, NULL, opt->q_cond+Nbas2, ao_loc, + atm, natm, bas, nbas, env); + } else { + CVHFnr_int2e_q_cond(int2e_cart, NULL, opt->q_cond+Nbas2, ao_loc, + atm, natm, bas, nbas, env); + } + CVHFnr_int2e_pp_q_cond(intor, cintopt, opt->q_cond, ao_loc, + atm, natm, bas, nbas, env); +} + void CVHFgrad_jk_direct_scf_dm(CVHFOpt *opt, double *dm, int nset, int *ao_loc, int *atm, int natm, int *bas, int nbas, double *env) { @@ -145,27 +153,8 @@ void CVHFgrad_jk_direct_scf_dm(CVHFOpt *opt, double *dm, int nset, int *ao_loc, free(opt->dm_cond); } nbas = opt->nbas; - size_t Nbas = nbas; opt->dm_cond = (double *)malloc(sizeof(double) * nbas*nbas); - NPdset0(opt->dm_cond, Nbas * Nbas); - - const size_t nao = ao_loc[nbas]; - double dmax; - int i, j, ish, jsh; - int iset; - double *pdm; - for (ish = 0; ish < nbas; ish++) { - for (jsh = 0; jsh < nbas; jsh++) { - dmax = 0; - for (iset = 0; iset < nset; iset++) { - pdm = dm + nao*nao*iset; - for (i = ao_loc[ish]; i < ao_loc[ish+1]; i++) { - for (j = ao_loc[jsh]; j < ao_loc[jsh+1]; j++) { - dmax = MAX(dmax, fabs(pdm[i*nao+j])); - } } - } - opt->dm_cond[ish*Nbas+jsh] = dmax; - } } + CVHFnr_dm_cond1(opt->dm_cond, dm, nset, ao_loc, atm, natm, bas, nbas, env); } @@ -246,33 +235,16 @@ int CVHFipip1_prescreen(int *shls, CVHFOpt *opt, || ( opt->dm_cond[j*n+l] > dmin)); } - -void CVHFipip1_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, - int *ao_loc, int *atm, int natm, - int *bas, int nbas, double *env) +void CVHFnr_int2e_pppp_q_cond(int (*intor)(), CINTOpt *cintopt, double *q_cond, + int *ao_loc, int *atm, int natm, + int *bas, int nbas, double *env) { - if (opt->q_cond != NULL) { - free(opt->q_cond); - } - nbas = opt->nbas; - size_t Nbas = nbas; - size_t Nbas2 = Nbas * Nbas; - // First n*n elements for derivatives, the next n*n elements for regular ERIs - opt->q_cond = (double *)malloc(sizeof(double) * nbas*nbas*2); - - if (ao_loc[nbas] == CINTtot_cgto_spheric(bas, nbas)) { - CVHFset_int2e_q_cond(int2e_sph, NULL, opt->q_cond+Nbas2, ao_loc, - atm, natm, bas, nbas, env); - } else { - CVHFset_int2e_q_cond(int2e_cart, NULL, opt->q_cond+Nbas2, ao_loc, - atm, natm, bas, nbas, env); - } - + int nbas2 = nbas * nbas; int shls_slice[] = {0, nbas}; const int cache_size = GTOmax_cache_size(intor, shls_slice, 1, atm, natm, bas, nbas, env); #pragma omp parallel \ - shared(opt, intor, cintopt, ao_loc, atm, natm, bas, nbas, env) + shared(intor, cintopt, ao_loc, atm, natm, bas, nbas, env) { double qtmp; int i, j, iijj, di, dj, ish, jsh; @@ -288,9 +260,9 @@ void CVHFipip1_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, double *bufxx = buf; double *bufxy, *bufxz, *bufyx, *bufyy, *bufyz, *bufzx, *bufzy, *bufzz; #pragma omp for schedule(dynamic, 4) - for (ij = 0; ij < Nbas2; ij++) { - ish = ij / Nbas; - jsh = ij - ish * Nbas; + for (ij = 0; ij < nbas2; ij++) { + ish = ij / nbas; + jsh = ij - ish * nbas; di = ao_loc[ish+1] - ao_loc[ish]; dj = ao_loc[jsh+1] - ao_loc[jsh]; shls[0] = ish; @@ -324,13 +296,37 @@ void CVHFipip1_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, } } qtmp = sqrt(qtmp); } - opt->q_cond[ish*nbas+jsh] = qtmp; + q_cond[ish*nbas+jsh] = qtmp; } free(buf); free(cache); } } +void CVHFipip1_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, + int *ao_loc, int *atm, int natm, + int *bas, int nbas, double *env) +{ + if (opt->q_cond != NULL) { + free(opt->q_cond); + } + nbas = opt->nbas; + size_t Nbas = nbas; + size_t Nbas2 = Nbas * Nbas; + // First n*n elements for derivatives, the next n*n elements for regular ERIs + opt->q_cond = (double *)malloc(sizeof(double) * nbas*nbas*2); + + if (ao_loc[nbas] == CINTtot_cgto_spheric(bas, nbas)) { + CVHFnr_int2e_q_cond(int2e_sph, NULL, opt->q_cond+Nbas2, ao_loc, + atm, natm, bas, nbas, env); + } else { + CVHFnr_int2e_q_cond(int2e_cart, NULL, opt->q_cond+Nbas2, ao_loc, + atm, natm, bas, nbas, env); + } + CVHFnr_int2e_pppp_q_cond(intor, cintopt, opt->q_cond, ao_loc, + atm, natm, bas, nbas, env); +} + void CVHFipip1_direct_scf_dm(CVHFOpt *opt, double *dm, int nset, int *ao_loc, int *atm, int natm, int *bas, int nbas, double *env) { diff --git a/pyscf/lib/vhf/nr_sgx_direct.c b/pyscf/lib/vhf/nr_sgx_direct.c index f56a1b2314..4ae821dde2 100644 --- a/pyscf/lib/vhf/nr_sgx_direct.c +++ b/pyscf/lib/vhf/nr_sgx_direct.c @@ -254,18 +254,10 @@ void SGXnr_direct_drv(int (*intor)(), void (*fdot)(), SGXJKOperator **jkop, } } - -void SGXsetnr_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, - int *ao_loc, int *atm, int natm, - int *bas, int nbas, double *env) +void SGXnr_q_cond(int (*intor)(), CINTOpt *cintopt, double *q_cond, + int *ao_loc, int *atm, int natm, + int *bas, int nbas, double *env) { - if (opt->q_cond != NULL) { - free(opt->q_cond); - } - nbas = opt->nbas; - double *q_cond = (double *)malloc(sizeof(double) * nbas*nbas); - opt->q_cond = q_cond; - int shls_slice[] = {0, nbas}; int cache_size = GTOmax_cache_size(intor, shls_slice, 1, atm, natm, bas, nbas, env); @@ -316,21 +308,24 @@ void SGXsetnr_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, } } -void SGXsetnr_direct_scf_dm(CVHFOpt *opt, double *dm, int nset, int *ao_loc, - int *atm, int natm, int *bas, int nbas, double *env, - int ngrids) +void SGXsetnr_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, + int *ao_loc, int *atm, int natm, + int *bas, int nbas, double *env) { - nbas = opt->nbas; - if (opt->dm_cond != NULL) { - free(opt->dm_cond); + if (opt->q_cond != NULL) { + free(opt->q_cond); } - opt->dm_cond = (double *)malloc(sizeof(double) * nbas*ngrids); - // nbas in the input arguments may different to opt->nbas. - // Use opt->nbas because it is used in the prescreen function - memset(opt->dm_cond, 0, sizeof(double)*nbas*ngrids); - opt->ngrids = ngrids; + nbas = opt->nbas; + double *q_cond = (double *)malloc(sizeof(double) * nbas*nbas); + opt->q_cond = q_cond; + SGXnr_q_cond(intor, cintopt, q_cond, ao_loc, atm, natm, bas, nbas, env); +} - const size_t nao = ao_loc[nbas] - ao_loc[0]; +void SGXnr_dm_cond(double *dm_cond, double *dm, int nset, int *ao_loc, + int *atm, int natm, int *bas, int nbas, double *env, + int ngrids) +{ + size_t nao = ao_loc[nbas] - ao_loc[0]; double dmax; size_t i, j, jsh, iset; double *pdm; @@ -343,10 +338,28 @@ void SGXsetnr_direct_scf_dm(CVHFOpt *opt, double *dm, int nset, int *ao_loc, dmax = MAX(dmax, fabs(pdm[i*nao+j])); } } - opt->dm_cond[jsh*ngrids+i] = dmax; + dm_cond[jsh*ngrids+i] = dmax; } } } +void SGXsetnr_direct_scf_dm(CVHFOpt *opt, double *dm, int nset, int *ao_loc, + int *atm, int natm, int *bas, int nbas, double *env, + int ngrids) +{ + nbas = opt->nbas; + if (opt->dm_cond != NULL) { + free(opt->dm_cond); + } + opt->dm_cond = (double *)malloc(sizeof(double) * nbas*ngrids); + // nbas in the input arguments may different to opt->nbas. + // Use opt->nbas because it is used in the prescreen function + memset(opt->dm_cond, 0, sizeof(double)*nbas*ngrids); + opt->ngrids = ngrids; + + SGXnr_dm_cond(opt->dm_cond, dm, nset, ao_loc, + atm, natm, bas, nbas, env, ngrids); +} + int SGXnr_ovlp_prescreen(int *shls, CVHFOpt *opt, int *atm, int *bas, double *env) { diff --git a/pyscf/lib/vhf/nr_sr_vhf.c b/pyscf/lib/vhf/nr_sr_vhf.c index 742b768a37..e0e5e08838 100644 --- a/pyscf/lib/vhf/nr_sr_vhf.c +++ b/pyscf/lib/vhf/nr_sr_vhf.c @@ -840,9 +840,9 @@ void CVHFnr_sr_direct_drv(int (*intor)(), void (*fdot)(), JKOperator **jkop, // sqrt(-log(1e-9)) #define R_GUESS_FAC 4.5f -void CVHFsetnr_sr_direct_scf(int (*intor)(), CINTOpt *cintopt, float *q_cond, - int *ao_loc, int *atm, int natm, - int *bas, int nbas, double *env) +void CVHFnr_sr_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, float *q_cond, + int *ao_loc, int *atm, int natm, + int *bas, int nbas, double *env) { size_t Nbas = nbas; size_t Nbas2 = Nbas * Nbas; @@ -994,3 +994,11 @@ void CVHFsetnr_sr_direct_scf(int (*intor)(), CINTOpt *cintopt, float *q_cond, } free(exps); } + +void CVHFsetnr_sr_direct_scf(int (*intor)(), CINTOpt *cintopt, float *q_cond, + int *ao_loc, int *atm, int natm, + int *bas, int nbas, double *env) +{ + CVHFnr_sr_int2e_q_cond(intor, cintopt, q_cond, ao_loc, + atm, natm, bas, nbas, env); +} diff --git a/pyscf/lib/vhf/optimizer.c b/pyscf/lib/vhf/optimizer.c index 7833da36f8..9c339a5ac1 100644 --- a/pyscf/lib/vhf/optimizer.c +++ b/pyscf/lib/vhf/optimizer.c @@ -464,6 +464,28 @@ void CVHFset_q_cond(CVHFOpt *opt, double *q_cond, int len) NPdcopy(opt->q_cond, q_cond, len); } +void CVHFnr_dm_cond1(double *dm_cond, double *dm, int nset, int *ao_loc, + int *atm, int natm, int *bas, int nbas, double *env) +{ + size_t nao = ao_loc[nbas]; + double dmax; + int i, j, ish, jsh; + int iset; + double *pdm; + for (ish = 0; ish < nbas; ish++) { + for (jsh = 0; jsh < nbas; jsh++) { + dmax = 0; + for (iset = 0; iset < nset; iset++) { + pdm = dm + nao*nao*iset; + for (i = ao_loc[ish]; i < ao_loc[ish+1]; i++) { + for (j = ao_loc[jsh]; j < ao_loc[jsh+1]; j++) { + dmax = MAX(dmax, fabs(pdm[i*nao+j])); + } } + } + dm_cond[ish*nbas+jsh] = dmax; + } } +} + void CVHFnr_dm_cond(double *dm_cond, double *dm, int nset, int *ao_loc, int *atm, int natm, int *bas, int nbas, double *env) { diff --git a/pyscf/lib/vhf/optimizer.h b/pyscf/lib/vhf/optimizer.h index 08d140991f..310a21a73a 100644 --- a/pyscf/lib/vhf/optimizer.h +++ b/pyscf/lib/vhf/optimizer.h @@ -65,3 +65,11 @@ void CVHFnr_optimizer(CVHFOpt **vhfopt, int (*intor)(), CINTOpt *cintopt, void CVHFset_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, double *q_cond, int *ao_loc, int *atm, int natm, int *bas, int nbas, double *env); + +void CVHFnr_int2e_q_cond(int (*intor)(), CINTOpt *cintopt, double *q_cond, + int *ao_loc, int *atm, int natm, + int *bas, int nbas, double *env); +void CVHFnr_dm_cond1(double *dm_cond, double *dm, int nset, int *ao_loc, + int *atm, int natm, int *bas, int nbas, double *env); +void CVHFnr_dm_cond(double *dm_cond, double *dm, int nset, int *ao_loc, + int *atm, int natm, int *bas, int nbas, double *env); diff --git a/pyscf/lib/vhf/rkb_screen.c b/pyscf/lib/vhf/rkb_screen.c index 7aa5ad458c..4ff0063140 100644 --- a/pyscf/lib/vhf/rkb_screen.c +++ b/pyscf/lib/vhf/rkb_screen.c @@ -73,13 +73,18 @@ int CVHFrkbllll_vkscreen(int *shls, CVHFOpt *opt, int k = shls[2]; int l = shls[3]; int nbas = opt->nbas; - int idm; double qijkl = opt->q_cond[i*nbas+j] * opt->q_cond[k*nbas+l]; - double *pdmscond = opt->dm_cond + nbas*nbas; - for (idm = 0; idm < (n_dm+1)/2; idm++) { + if (n_dm <= 2) { + int nbas2 = nbas * nbas; + double *pdmscond = opt->dm_cond + nbas2; // note in _vhf.rdirect_mapdm, J and K share the same DM - dms_cond[idm*2+0] = pdmscond + idm*nbas*nbas; // for vj - dms_cond[idm*2+1] = pdmscond + idm*nbas*nbas; // for vk + dms_cond[0] = pdmscond; // for vj + dms_cond[1] = pdmscond; // for vk + } else { + int idm; + for (idm = 0; idm < n_dm; idm++) { + dms_cond[idm] = opt->dm_cond; + } } *dm_atleast = opt->direct_scf_cutoff / qijkl; return 1; @@ -125,26 +130,27 @@ int CVHFrkbssll_vkscreen(int *shls, CVHFOpt *opt, int k = shls[2]; int l = shls[3]; int nbas = opt->nbas; + int nbas2 = nbas * nbas; int idm; - double qijkl = opt->q_cond[nbas*nbas*SS+i*nbas+j] * opt->q_cond[k*nbas+l]; - double *pdmscond = opt->dm_cond + 4*nbas*nbas; + double qijkl = opt->q_cond[nbas2*SS+i*nbas+j] * opt->q_cond[k*nbas+l]; + double *dm_cond = opt->dm_cond; int nset = (n_dm+2) / 3; - double *dmscondll = pdmscond + nset*nbas*nbas*LL; - double *dmscondss = pdmscond + nset*nbas*nbas*SS; - double *dmscondsl = pdmscond + nset*nbas*nbas*SL; + double *dmscondll = dm_cond + (1+nset)*nbas2*LL + nbas2; + double *dmscondss = dm_cond + (1+nset)*nbas2*SS + nbas2; + double *dmscondsl = dm_cond + (1+nset)*nbas2*SL + nbas2; for (idm = 0; idm < nset; idm++) { - dms_cond[nset*0+idm] = dmscondll + idm*nbas*nbas; - dms_cond[nset*1+idm] = dmscondss + idm*nbas*nbas; - dms_cond[nset*2+idm] = dmscondsl + idm*nbas*nbas; + dms_cond[nset*0+idm] = dmscondll + idm*nbas2; + dms_cond[nset*1+idm] = dmscondss + idm*nbas2; + dms_cond[nset*2+idm] = dmscondsl + idm*nbas2; } *dm_atleast = opt->direct_scf_cutoff / qijkl; return 1; } -static void set_qcond(int (*intor)(), CINTOpt *cintopt, double *qcond, - int *ao_loc, int *atm, int natm, - int *bas, int nbas, double *env) +void CVHFrkb_q_cond(int (*intor)(), CINTOpt *cintopt, double *qcond, + int *ao_loc, int *atm, int natm, + int *bas, int nbas, double *env) { int shls_slice[] = {0, nbas}; const int cache_size = GTOmax_cache_size(intor, shls_slice, 1, @@ -198,7 +204,7 @@ void CVHFrkbllll_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, opt->q_cond = (double *)malloc(sizeof(double) * nbas*nbas); assert(intor == &int2e_spinor); - set_qcond(intor, cintopt, opt->q_cond, ao_loc, atm, natm, bas, nbas, env); + CVHFrkb_q_cond(intor, cintopt, opt->q_cond, ao_loc, atm, natm, bas, nbas, env); } void CVHFrkbssss_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, @@ -211,7 +217,7 @@ void CVHFrkbssss_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, opt->q_cond = (double *)malloc(sizeof(double) * nbas*nbas); assert(intor == &int2e_spsp1spsp2_spinor); - set_qcond(intor, cintopt, opt->q_cond, ao_loc, atm, natm, bas, nbas, env); + CVHFrkb_q_cond(intor, cintopt, opt->q_cond, ao_loc, atm, natm, bas, nbas, env); } @@ -224,16 +230,16 @@ void CVHFrkbssll_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt, } opt->q_cond = (double *)malloc(sizeof(double) * nbas*nbas*2); - set_qcond(&int2e_spinor, NULL, opt->q_cond, ao_loc, atm, natm, bas, nbas, env); - set_qcond(&int2e_spsp1spsp2_spinor, NULL, opt->q_cond+nbas*nbas, ao_loc, + CVHFrkb_q_cond(&int2e_spinor, NULL, opt->q_cond, ao_loc, atm, natm, bas, nbas, env); + CVHFrkb_q_cond(&int2e_spsp1spsp2_spinor, NULL, opt->q_cond+nbas*nbas, ao_loc, atm, natm, bas, nbas, env); } -static void set_dmcond(double *dmcond, double *dmscond, double complex *dm, - double direct_scf_cutoff, int nset, int *ao_loc, - int *atm, int natm, int *bas, int nbas, double *env) +void CVHFrkb_dm_cond(double *dmcond, double complex *dm, int nset, int *ao_loc, + int *atm, int natm, int *bas, int nbas, double *env) { - const size_t nao = ao_loc[nbas]; + double *dmscond = dmcond + nbas * nbas; + size_t nao = ao_loc[nbas]; double dmax, dmaxi, tmp; int i, j, ish, jsh; int iset; @@ -270,8 +276,7 @@ void CVHFrkbllll_direct_scf_dm(CVHFOpt *opt, double complex *dm, int nset, opt->dm_cond = (double *)malloc(sizeof(double)*nbas*nbas*(1+nset)); NPdset0(opt->dm_cond, ((size_t)nbas)*nbas*(1+nset)); // dmcond followed by dmscond which are max matrix element for each dm - set_dmcond(opt->dm_cond, opt->dm_cond+nbas*nbas, dm, - opt->direct_scf_cutoff, nset, ao_loc, atm, natm, bas, nbas, env); + CVHFrkb_dm_cond(opt->dm_cond, dm, nset, ao_loc, atm, natm, bas, nbas, env); } void CVHFrkbssss_direct_scf_dm(CVHFOpt *opt, double complex *dm, int nset, @@ -283,53 +288,32 @@ void CVHFrkbssss_direct_scf_dm(CVHFOpt *opt, double complex *dm, int nset, } opt->dm_cond = (double *)malloc(sizeof(double)*nbas*nbas*(1+nset)); NPdset0(opt->dm_cond, ((size_t)nbas)*nbas*(1+nset)); - set_dmcond(opt->dm_cond, opt->dm_cond+nbas*nbas, dm, - opt->direct_scf_cutoff, nset, ao_loc, atm, natm, bas, nbas, env); + CVHFrkb_dm_cond(opt->dm_cond, dm, nset, ao_loc, atm, natm, bas, nbas, env); } // the current order of dmscond (dmll, dmss, dmsl) is consistent to the // function _call_veff_ssll in dhf.py -void CVHFrkbssll_direct_scf_dm(CVHFOpt *opt, double complex *dm, int nset, - int *ao_loc, int *atm, int natm, - int *bas, int nbas, double *env) +void CVHFrkbssll_dm_cond(double *dm_cond, double complex *dm, int nset, int *ao_loc, + int *atm, int natm, int *bas, int nbas, double *env) { - if (opt->dm_cond != NULL) { - free(opt->dm_cond); - } - if (nset < 4) { - fprintf(stderr, "4 sets of DMs (dmll,dmss,dmsl,dmls) are " - "required to set rkb prescreening\n"); - exit(1); - } nset = nset / 4; int n2c = CINTtot_cgto_spinor(bas, nbas); size_t nbas2 = nbas * nbas; - opt->dm_cond = (double *)malloc(sizeof(double)*nbas2*4*(1+nset)); - NPdset0(opt->dm_cond, nbas2*4*(1+nset)); - - // 4 types of dmcond (LL,SS,SL,LS) followed by 4 types of dmscond - double *dmcondll = opt->dm_cond + nbas2*LL; - double *dmcondss = opt->dm_cond + nbas2*SS; - double *dmcondsl = opt->dm_cond + nbas2*SL; - double *dmcondls = opt->dm_cond + nbas2*LS; - double *pdmscond = opt->dm_cond + nbas2*4; - double *dmscondll = pdmscond + nset*nbas2*LL; - double *dmscondss = pdmscond + nset*nbas2*SS; - double *dmscondsl = pdmscond + nset*nbas2*SL; - double *dmscondls = pdmscond + nset*nbas2*LS; + double *dmcondll = dm_cond + (1+nset)*nbas2*LL; + double *dmcondss = dm_cond + (1+nset)*nbas2*SS; + double *dmcondsl = dm_cond + (1+nset)*nbas2*SL; + double *dmcondls = dm_cond + (1+nset)*nbas2*LS; + double *dmscondls = dmcondls + nbas2; + double *dmscondsl = dmcondsl + nbas2; double complex *dmll = dm + n2c*n2c*LL*nset; double complex *dmss = dm + n2c*n2c*SS*nset; double complex *dmsl = dm + n2c*n2c*SL*nset; double complex *dmls = dm + n2c*n2c*LS*nset; - set_dmcond(dmcondll, dmscondll, dmll, - opt->direct_scf_cutoff, nset, ao_loc, atm, natm, bas, nbas, env); - set_dmcond(dmcondss, dmscondss, dmss, - opt->direct_scf_cutoff, nset, ao_loc, atm, natm, bas, nbas, env); - set_dmcond(dmcondsl, dmscondsl, dmsl, - opt->direct_scf_cutoff, nset, ao_loc, atm, natm, bas, nbas, env); - set_dmcond(dmcondls, dmscondls, dmls, - opt->direct_scf_cutoff, nset, ao_loc, atm, natm, bas, nbas, env); + CVHFrkb_dm_cond(dmcondll, dmll, nset, ao_loc, atm, natm, bas, nbas, env); + CVHFrkb_dm_cond(dmcondss, dmss, nset, ao_loc, atm, natm, bas, nbas, env); + CVHFrkb_dm_cond(dmcondsl, dmsl, nset, ao_loc, atm, natm, bas, nbas, env); + CVHFrkb_dm_cond(dmcondls, dmls, nset, ao_loc, atm, natm, bas, nbas, env); // aggregate dmcondls to dmcondsl int i, j, n; @@ -346,3 +330,24 @@ void CVHFrkbssll_direct_scf_dm(CVHFOpt *opt, double complex *dm, int nset, dmscondls += nbas2; } } + +// the current order of dmscond (dmll, dmss, dmsl) is consistent to the +// function _call_veff_ssll in dhf.py +void CVHFrkbssll_direct_scf_dm(CVHFOpt *opt, double complex *dm, int nset, + int *ao_loc, int *atm, int natm, + int *bas, int nbas, double *env) +{ + if (opt->dm_cond != NULL) { + free(opt->dm_cond); + } + if (nset < 4) { + fprintf(stderr, "4 sets of DMs (dmll,dmss,dmsl,dmls) are " + "required to set rkb prescreening\n"); + exit(1); + } + nset = nset / 4; + size_t nbas2 = nbas * nbas; + opt->dm_cond = (double *)malloc(sizeof(double)*nbas2*4*(1+nset)); + CVHFrkbssll_dm_cond(opt->dm_cond, dm, nset, ao_loc, + atm, natm, bas, nbas, env); +} diff --git a/pyscf/pbc/scf/rsjk.py b/pyscf/pbc/scf/rsjk.py index eec9d8c1ca..64262b464e 100644 --- a/pyscf/pbc/scf/rsjk.py +++ b/pyscf/pbc/scf/rsjk.py @@ -205,14 +205,14 @@ def build(self, omega=None, intor='int2e'): qindex = np.empty((3,nbas,nbas), dtype=np.int16) ao_loc = supmol_sr.ao_loc with supmol_sr.with_integral_screen(self.direct_scf_tol**2): - libpbc.PBCVHFsetnr_direct_scf( + libpbc.PBCVHFnr_int2e_q_cond( libpbc.int2e_sph, self._cintopt, qindex.ctypes.data_as(ctypes.c_void_p), ao_loc.ctypes.data_as(ctypes.c_void_p), supmol_sr._atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(supmol_sr.natm), supmol_sr._bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(supmol_sr.nbas), supmol_sr._env.ctypes.data_as(ctypes.c_void_p)) - libpbc.PBCVHFsetnr_sindex( + libpbc.PBCVHFnr_sindex( qindex[2:].ctypes.data_as(ctypes.c_void_p), supmol_sr._atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(supmol_sr.natm), supmol_sr._bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(supmol_sr.nbas), diff --git a/pyscf/scf/_vhf.py b/pyscf/scf/_vhf.py index 7f7c9efaed..bd39196577 100644 --- a/pyscf/scf/_vhf.py +++ b/pyscf/scf/_vhf.py @@ -147,9 +147,10 @@ def get_dm_cond(self, shape=None): return numpy.ctypeslib.as_array(data, shape=shape) dm_cond = property(get_dm_cond) +# TODO: replace VHFOpt in future release class _VHFOpt: - def __init__(self, mol, intor=None, - prescreen='CVHFnoscreen', qcondname=None, dmcondname=None): + def __init__(self, mol, intor=None, prescreen='CVHFnoscreen', + qcondname=None, dmcondname=None, direct_scf_tol=1e-14): '''New version of VHFOpt (under development). If function "qcondname" is presented, the qcond (sqrt(integrals)) @@ -159,19 +160,20 @@ def __init__(self, mol, intor=None, names of C functions defined in libcvhf module ''' self.mol = mol - self._q_cond = None - self._dm_cond = None self._this = cvhfopt = _CVHFOpt() cvhfopt.nbas = mol.nbas - cvhfopt.direct_scf_cutoff = 1e-14 + cvhfopt.direct_scf_cutoff = direct_scf_tol cvhfopt.fprescreen = _fpointer(prescreen) cvhfopt.r_vkscreen = _fpointer('CVHFr_vknoscreen') + self._q_cond = None + self._dm_cond = None if intor is None: self._intor = intor self._cintopt = lib.c_null_ptr() else: - self._intor = mol._add_suffix(intor) + intor = mol._add_suffix(intor) + self._intor = intor self._cintopt = make_cintopt(mol._atm, mol._bas, mol._env, intor) self._dmcondname = dmcondname @@ -184,19 +186,23 @@ def init_cvhf_direct(self, mol, intor, qcondname): defined in libcvhf module ''' intor = mol._add_suffix(intor) - assert intor == self._intor - cintopt = self._cintopt - ao_loc = mol.ao_loc_nr() + if intor == self._intor: + cintopt = self._cintopt + else: + cintopt = lib.c_null_ptr() + ao_loc = make_loc(mol._bas, intor) if isinstance(qcondname, ctypes._CFuncPtr): fqcond = qcondname else: fqcond = getattr(libcvhf, qcondname) nbas = mol.nbas - q_cond = self._q_cond = numpy.empty((nbas, nbas)) - fqcond(getattr(libcvhf, intor), cintopt, q_cond.ctypes, - ao_loc.ctypes, mol._atm.ctypes, ctypes.c_int(mol.natm), - mol._bas.ctypes, ctypes.c_int(nbas), mol._env.ctypes) - self._this.q_cond = q_cond.ctypes.data_as(ctypes.c_void_p) + q_cond = numpy.empty((nbas, nbas)) + with mol.with_integral_screen(self.direct_scf_tol**2): + fqcond(getattr(libcvhf, intor), cintopt, q_cond.ctypes, + ao_loc.ctypes, mol._atm.ctypes, ctypes.c_int(mol.natm), + mol._bas.ctypes, ctypes.c_int(nbas), mol._env.ctypes) + + self.q_cond = q_cond self._qcondname = qcondname @property @@ -225,7 +231,7 @@ def set_dm(self, dm, atm, bas, env): else: n_dm = len(dm) dm = numpy.asarray(dm, order='C') - ao_loc = mol.ao_loc_nr() + ao_loc = make_loc(mol._bas, self._intor) if isinstance(self._dmcondname, ctypes._CFuncPtr): fdmcond = self._dmcondname else: @@ -235,71 +241,58 @@ def set_dm(self, dm, atm, bas, env): fdmcond(dm_cond.ctypes, dm.ctypes, ctypes.c_int(n_dm), ao_loc.ctypes, mol._atm.ctypes, ctypes.c_int(mol.natm), mol._bas.ctypes, ctypes.c_int(nbas), mol._env.ctypes) - self._dm_cond = dm_cond - self._this.dm_cond = dm_cond.ctypes.data_as(ctypes.c_void_p) + self.dm_cond = dm_cond - def get_q_cond(self, shape=None): - '''Return an array associated to q_cond. Contents of q_cond can be - modified through this array - ''' + def get_q_cond(self): return self._q_cond q_cond = property(get_q_cond) - def get_dm_cond(self, shape=None): - '''Return an array associated to dm_cond. Contents of dm_cond can be - modified through this array - ''' + @q_cond.setter + def q_cond(self, q_cond): + self._q_cond = q_cond + if q_cond is not None: + self._this.q_cond = q_cond.ctypes.data_as(ctypes.c_void_p) + + def get_dm_cond(self): return self._dm_cond dm_cond = property(get_dm_cond) - -class SGXOpt(VHFOpt): - def __init__(self, mol, intor=None, - prescreen='CVHFnoscreen', qcondname=None, dmcondname=None): - super(SGXOpt, self).__init__(mol, intor, prescreen, qcondname, dmcondname) + @dm_cond.setter + def dm_cond(self, dm_cond): + self._dm_cond = dm_cond + if dm_cond is not None: + self._this.dm_cond = dm_cond.ctypes.data_as(ctypes.c_void_p) + +class SGXOpt(_VHFOpt): + def __init__(self, mol, intor=None, prescreen='CVHFnoscreen', + qcondname=None, dmcondname=None, direct_scf_cutoff=1e-14): + _VHFOpt.__init__(self, mol, intor, prescreen, qcondname, dmcondname, + direct_scf_cutoff) self.ngrids = None def set_dm(self, dm, atm, bas, env): - if self._dmcondname is not None: - c_atm = numpy.asarray(atm, dtype=numpy.int32, order='C') - c_bas = numpy.asarray(bas, dtype=numpy.int32, order='C') - c_env = numpy.asarray(env, dtype=numpy.double, order='C') - natm = ctypes.c_int(c_atm.shape[0]) - nbas = ctypes.c_int(c_bas.shape[0]) - if isinstance(dm, numpy.ndarray) and dm.ndim == 2: - n_dm = 1 - ngrids = dm.shape[0] - else: - n_dm = len(dm) - ngrids = dm.shape[1] - dm = numpy.asarray(dm, order='C') - ao_loc = make_loc(c_bas, self._intor) - if isinstance(self._dmcondname, ctypes._CFuncPtr): - fsetdm = self._dmcondname - else: - fsetdm = getattr(libcvhf, self._dmcondname) - if self._dmcondname == 'SGXsetnr_direct_scf_dm': - fsetdm(self._this, - dm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(n_dm), - ao_loc.ctypes.data_as(ctypes.c_void_p), - c_atm.ctypes.data_as(ctypes.c_void_p), natm, - c_bas.ctypes.data_as(ctypes.c_void_p), nbas, - c_env.ctypes.data_as(ctypes.c_void_p), - ctypes.c_int(ngrids)) - self.ngrids = ngrids - else: - raise ValueError('Can only use SGX dm screening for SGXOpt') + if self._dmcondname is None: + return - def get_dm_cond(self): - '''Return an array associated to dm_cond. Contents of dm_cond can be - modified through this array - ''' - nbas = self._this.contents.nbas - shape = (nbas, self.ngrids) - data = ctypes.cast(self._this.contents.dm_cond, - ctypes.POINTER(ctypes.c_double)) - return numpy.ctypeslib.as_array(data, shape=shape) - dm_cond = property(get_dm_cond) + mol = self.mol + if isinstance(dm, numpy.ndarray) and dm.ndim == 2: + n_dm = 1 + else: + n_dm = len(dm) + dm = numpy.asarray(dm, order='C') + ao_loc = make_loc(mol._bas, self._intor) + if isinstance(self._dmcondname, ctypes._CFuncPtr): + fdmcond = self._dmcondname + else: + if self._dmcondname != 'SGXnr_dm_cond': + raise ValueError('SGXOpt only supports SGXnr_dm_cond') + fdmcond = getattr(libcvhf, self._dmcondname) + dm_cond = numpy.empty((mol.nbas, self.ngrids)) + fdmcond(dm_cond.ctypes, dm.ctypes, ctypes.c_int(n_dm), + ao_loc.ctypes, mol._atm.ctypes, ctypes.c_int(mol.natm), + mol._bas.ctypes, ctypes.c_int(mol.nbas), mol._env.ctypes, + ctypes.c_int(self.ngrids)) + self.dm_cond = dm_cond class _CVHFOpt(ctypes.Structure): @@ -477,6 +470,15 @@ def direct_mapdm(intor, aosym, jkdescript, single_dm = True else: single_dm = False + intor = ascint3(intor) + if vhfopt is None: + cvhfopt = lib.c_null_ptr() + cintopt = None + else: + vhfopt.set_dm(dms, atm, bas, env) + cvhfopt = vhfopt._this + cintopt = vhfopt._cintopt + n_dm = len(dms) dms = [numpy.asarray(dm, order='C', dtype=numpy.double) for dm in dms] if isinstance(jkdescript, str): @@ -490,15 +492,6 @@ def direct_mapdm(intor, aosym, jkdescript, # make n_dm copies for each jk script jkscripts = numpy.repeat(jkscripts, n_dm) - intor = ascint3(intor) - if vhfopt is None: - cvhfopt = lib.c_null_ptr() - cintopt = None - else: - vhfopt.set_dm(dms, atm, bas, env) - cvhfopt = vhfopt._this - cintopt = vhfopt._cintopt - vjk = nr_direct_drv(intor, aosym, jkscripts, dms, ncomp, atm, bas, env, cvhfopt, cintopt, shls_slice, shls_excludes, out, optimize_sr=optimize_sr) @@ -602,7 +595,7 @@ def nr_direct_drv(intor, aosym, jkscript, buf = numpy.empty(vshape) else: buf = out[i] - assert buf.shape == vshape + assert buf.size == numpy.prod(vshape) assert buf.dtype == numpy.double assert buf.flags.c_contiguous vjk.append(buf) @@ -678,7 +671,10 @@ def rdirect_mapdm(intor, aosym, jkdescript, cvhfopt = lib.c_null_ptr() else: vhfopt.set_dm(dms, atm, bas, env) - cvhfopt = vhfopt._this + if isinstance(vhfopt, _VHFOpt): + cvhfopt = ctypes.byref(vhfopt._this) + else: + cvhfopt = vhfopt._this cintopt = vhfopt._cintopt cintor = getattr(libcvhf, vhfopt._intor) if cintopt is None: @@ -750,7 +746,10 @@ def rdirect_bindm(intor, aosym, jkdescript, cvhfopt = lib.c_null_ptr() else: vhfopt.set_dm(dms, atm, bas, env) - cvhfopt = vhfopt._this + if isinstance(vhfopt, _VHFOpt): + cvhfopt = ctypes.byref(vhfopt._this) + else: + cvhfopt = vhfopt._this cintopt = vhfopt._cintopt cintor = getattr(libcvhf, vhfopt._intor) if cintopt is None: diff --git a/pyscf/scf/dhf.py b/pyscf/scf/dhf.py index ac950fed35..54cfcc6c72 100644 --- a/pyscf/scf/dhf.py +++ b/pyscf/scf/dhf.py @@ -22,6 +22,7 @@ from functools import reduce +import ctypes import numpy from pyscf import lib from pyscf import gto @@ -554,38 +555,32 @@ def get_occ(self, mo_energy=None, mo_coeff=None): def init_direct_scf(self, mol=None): if mol is None: mol = self.mol def set_vkscreen(opt, name): - opt._this.contents.r_vkscreen = _vhf._fpointer(name) + opt._this.r_vkscreen = _vhf._fpointer(name) cpu0 = (logger.process_clock(), logger.perf_counter()) - with mol.with_integral_screen(self.direct_scf_tol**2): - opt_llll = _vhf.VHFOpt(mol, 'int2e_spinor', 'CVHFrkbllll_prescreen', - 'CVHFrkbllll_direct_scf', - 'CVHFrkbllll_direct_scf_dm') - opt_llll.direct_scf_tol = self.direct_scf_tol - set_vkscreen(opt_llll, 'CVHFrkbllll_vkscreen') - opt_ssss = _vhf.VHFOpt(mol, 'int2e_spsp1spsp2_spinor', - 'CVHFrkbllll_prescreen', - 'CVHFrkbssss_direct_scf', - 'CVHFrkbssss_direct_scf_dm') - c1 = .5 / lib.param.LIGHT_SPEED - q_cond = opt_ssss.get_q_cond() - q_cond *= c1**2 - opt_ssss.direct_scf_tol = self.direct_scf_tol - set_vkscreen(opt_ssss, 'CVHFrkbllll_vkscreen') - opt_ssll = _vhf.VHFOpt(mol, 'int2e_spsp1_spinor', - 'CVHFrkbssll_prescreen', - 'CVHFrkbssll_direct_scf', - 'CVHFrkbssll_direct_scf_dm') - opt_ssll.direct_scf_tol = self.direct_scf_tol - set_vkscreen(opt_ssll, 'CVHFrkbssll_vkscreen') - nbas = mol.nbas - # The second parts of q_cond corresponds to ssss integrals. They - # need to be scaled by the factor (1/2c)^2 - q_cond = opt_ssll.get_q_cond(shape=(2, nbas, nbas)) - q_cond[1] *= c1**2 - -#TODO: prescreen for gaunt - opt_gaunt = None + opt_llll = _VHFOpt(mol, 'int2e_spinor', 'CVHFrkbllll_prescreen', + 'CVHFrkb_q_cond', 'CVHFrkb_dm_cond', + direct_scf_tol=self.direct_scf_tol) + set_vkscreen(opt_llll, 'CVHFrkbllll_vkscreen') + + c1 = .5 / lib.param.LIGHT_SPEED + opt_ssss = _VHFOpt(mol, 'int2e_spsp1spsp2_spinor', + 'CVHFrkbllll_prescreen', 'CVHFrkb_q_cond', + 'CVHFrkb_dm_cond', + direct_scf_tol=self.direct_scf_tol/c1**4) + opt_ssss.direct_scf_tol = self.direct_scf_tol + opt_ssss.q_cond *= c1**2 + set_vkscreen(opt_ssss, 'CVHFrkbllll_vkscreen') + + opt_ssll = _VHFOpt(mol, 'int2e_spsp1_spinor', + 'CVHFrkbssll_prescreen', + dmcondname='CVHFrkbssll_dm_cond', + direct_scf_tol=self.direct_scf_tol) + opt_ssll.q_cond = numpy.array([opt_llll.q_cond, opt_ssss.q_cond]) + set_vkscreen(opt_ssll, 'CVHFrkbssll_vkscreen') + + #TODO: prescreen for gaunt + opt_gaunt = None logger.timer(self, 'init_direct_scf', *cpu0) return opt_llll, opt_ssll, opt_ssss, opt_gaunt @@ -989,6 +984,29 @@ def _proj_dmll(mol_nr, dm_nr, mol): dm[:n2c,:n2c] = (dm_ll + time_reversal_matrix(mol, dm_ll)) * .5 return dm +class _VHFOpt(_vhf._VHFOpt): + def set_dm(self, dm, atm, bas, env): + if self._dmcondname is None: + return + + mol = self.mol + if isinstance(dm, numpy.ndarray) and dm.ndim == 2: + n_dm = 1 + else: + n_dm = len(dm) + dm = numpy.asarray(dm, order='C') + ao_loc = mol.ao_loc_2c() + if isinstance(self._dmcondname, ctypes._CFuncPtr): + fdmcond = self._dmcondname + else: + fdmcond = getattr(_vhf.libcvhf, self._dmcondname) + nbas = mol.nbas + dm_cond = numpy.empty((n_dm*2, nbas, nbas)) + fdmcond(dm_cond.ctypes, dm.ctypes, ctypes.c_int(n_dm), + ao_loc.ctypes, mol._atm.ctypes, ctypes.c_int(mol.natm), + mol._bas.ctypes, ctypes.c_int(nbas), mol._env.ctypes) + self.dm_cond = dm_cond + if __name__ == '__main__': import pyscf.gto diff --git a/pyscf/scf/hf.py b/pyscf/scf/hf.py index 1f157ed402..409956f9ea 100644 --- a/pyscf/scf/hf.py +++ b/pyscf/scf/hf.py @@ -1752,10 +1752,9 @@ def init_direct_scf(self, mol=None): # Integrals < direct_scf_tol may be set to 0 in int2e. # Higher accuracy is required for Schwartz inequality prescreening. cpu0 = (logger.process_clock(), logger.perf_counter()) - with mol.with_integral_screen(self.direct_scf_tol**2): - opt = _vhf._VHFOpt(mol, 'int2e', 'CVHFnrs8_prescreen', - 'CVHFnr_int2e_q_cond', 'CVHFnr_dm_cond') - opt.direct_scf_tol = self.direct_scf_tol + opt = _vhf._VHFOpt(mol, 'int2e', 'CVHFnrs8_prescreen', + 'CVHFnr_int2e_q_cond', 'CVHFnr_dm_cond', + self.direct_scf_tol) logger.timer(self, 'init_direct_scf', *cpu0) return opt diff --git a/pyscf/scf/test/test_vhf.py b/pyscf/scf/test/test_vhf.py index 5bada26b36..0b69af1de3 100644 --- a/pyscf/scf/test/test_vhf.py +++ b/pyscf/scf/test/test_vhf.py @@ -202,12 +202,13 @@ def test_direct_sr_vhf(self): class _VHFOpt(_vhf._VHFOpt): def __init__(self, mol, intor=None, prescreen='CVHFnoscreen', - qcondname=None, dmcondname=None, omega=None): + qcondname=None, dmcondname=None, direct_scf_tol=1e-14, + omega=None): assert omega is not None with mol.with_short_range_coulomb(omega): _vhf._VHFOpt.__init__(self, mol, intor, prescreen, qcondname, dmcondname) self.omega = omega - self._this.direct_scf_cutoff = numpy.log(1e-14) + self._this.direct_scf_cutoff = numpy.log(direct_scf_tol) def init_cvhf_direct(self, mol, intor=None, qcondname=None): nbas = mol.nbas diff --git a/pyscf/sgx/sgx.py b/pyscf/sgx/sgx.py index 98750e2f12..a003c7c50e 100644 --- a/pyscf/sgx/sgx.py +++ b/pyscf/sgx/sgx.py @@ -225,21 +225,17 @@ def method_not_implemented(self, *args, **kwargs): mcscf.casci.CASCI.COSX = sgx_fit -def _make_opt(mol, pjs=False): +def _make_opt(mol, pjs=False, + direct_scf_tol=getattr(__config__, 'scf_hf_SCF_direct_scf_tol', 1e-13)): '''Optimizer to genrate 3-center 2-electron integrals''' - intor = mol._add_suffix('int1e_grids') - cintopt = gto.moleintor.make_cintopt(mol._atm, mol._bas, mol._env, intor) - # intor 'int1e_ovlp' is used by the prescreen method - # 'SGXnr_ovlp_prescreen' only. Not used again in other places. - # It can be released early if pjs: - vhfopt = _vhf.SGXOpt(mol, 'int1e_ovlp', 'SGXnr_ovlp_prescreen', - 'SGXsetnr_direct_scf', 'SGXsetnr_direct_scf_dm') + vhfopt = _vhf.SGXOpt(mol, 'int1e_grids', 'SGXnr_ovlp_prescreen', + dmcondname='SGXnr_dm_cond', + direct_scf_tol=direct_scf_tol) else: - vhfopt = _vhf.VHFOpt(mol, 'int1e_ovlp', 'SGXnr_ovlp_prescreen', - 'SGXsetnr_direct_scf') - vhfopt._intor = intor - vhfopt._cintopt = cintopt + vhfopt = _vhf._VHFOpt(mol, 'int1e_grids', 'SGXnr_ovlp_prescreen', + direct_scf_tol=direct_scf_tol) + vhfopt.init_cvhf_direct(mol, 'int1e_ovlp', 'SGXnr_q_cond') return vhfopt @@ -366,21 +362,3 @@ def get_jk(self, dm, hermi=1, vhfopt=None, with_j=True, with_k=True, else: vj, vk = sgx_jk.get_jk(self, dm, hermi, with_j, with_k, direct_scf_tol) return vj, vk - - -if __name__ == '__main__': - from pyscf import scf - mol = gto.Mole() - mol.build( - atom = [["O" , (0. , 0. , 0.)], - [1 , (0. , -0.757 , 0.587)], - [1 , (0. , 0.757 , 0.587)] ], - basis = 'ccpvdz', - ) - method = sgx_fit(scf.RHF(mol), 'weigend') - energy = method.scf() - print(energy - -76.02673747045691) - - method.with_df.dfj = True - energy = method.scf() - print(energy - -76.02686422219752) diff --git a/pyscf/sgx/sgx_jk.py b/pyscf/sgx/sgx_jk.py index b92f9995c9..090c875e4a 100644 --- a/pyscf/sgx/sgx_jk.py +++ b/pyscf/sgx/sgx_jk.py @@ -260,7 +260,7 @@ def _gen_jk_direct(mol, aosym, with_j, with_k, direct_scf_tol, sgxopt=None, pjs= ''' if sgxopt is None: from pyscf.sgx import sgx - sgxopt = sgx._make_opt(mol, pjs=pjs) + sgxopt = sgx._make_opt(mol, pjs, direct_scf_tol) sgxopt.direct_scf_tol = direct_scf_tol ncomp = 1 @@ -316,7 +316,7 @@ def jk_part(mol, grid_coords, dms, fg, weights): drv(cintor, fdot, fjk, dmsptr, vjkptr, n_dm, ncomp, (ctypes.c_int*4)(*shls_slice), ao_loc.ctypes.data_as(ctypes.c_void_p), - sgxopt._cintopt, sgxopt._this, + sgxopt._cintopt, ctypes.byref(sgxopt._this), atm.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.natm), bas.ctypes.data_as(ctypes.c_void_p), ctypes.c_int(mol.nbas), env.ctypes.data_as(ctypes.c_void_p), @@ -355,45 +355,3 @@ def get_gridss(mol, level=1, gthrd=1e-10): return grids get_jk = get_jk_favorj - - -if __name__ == '__main__': - from pyscf import scf - from pyscf.sgx import sgx - mol = gto.Mole() - mol.build( - verbose = 0, - atom = [["O" , (0. , 0. , 0.)], - [1 , (0. , -0.757 , 0.587)], - [1 , (0. , 0.757 , 0.587)] ], - basis = 'ccpvdz', - ) - dm = scf.RHF(mol).run().make_rdm1() - vjref, vkref = scf.hf.get_jk(mol, dm) - print(numpy.einsum('ij,ji->', vjref, dm)) - print(numpy.einsum('ij,ji->', vkref, dm)) - - sgxobj = sgx.SGX(mol) - sgxobj.grids = get_gridss(mol, 0, 1e-10) - with lib.temporary_env(sgxobj, debug=True): - vj, vk = get_jk_favork(sgxobj, dm) - print(numpy.einsum('ij,ji->', vj, dm)) - print(numpy.einsum('ij,ji->', vk, dm)) - print(abs(vjref-vj).max().max()) - print(abs(vkref-vk).max().max()) - with lib.temporary_env(sgxobj, debug=False): - vj1, vk1 = get_jk_favork(sgxobj, dm) - print(abs(vj - vj1).max()) - print(abs(vk - vk1).max()) - - with lib.temporary_env(sgxobj, debug=True): - vj, vk = get_jk_favorj(sgxobj, dm) - print(numpy.einsum('ij,ji->', vj, dm)) - print(numpy.einsum('ij,ji->', vk, dm)) - print(abs(vjref-vj).max().max()) - print(abs(vkref-vk).max().max()) - - with lib.temporary_env(sgxobj, debug=False): - vj1, vk1 = get_jk_favorj(sgxobj, dm) - print(abs(vj - vj1).max()) - print(abs(vk - vk1).max()) diff --git a/pyscf/sgx/test/test_sgx.py b/pyscf/sgx/test/test_sgx.py index 7a1b546509..307aa9a037 100644 --- a/pyscf/sgx/test/test_sgx.py +++ b/pyscf/sgx/test/test_sgx.py @@ -30,8 +30,19 @@ def test_reset(self): self.assertTrue(mf.mol is mol1) self.assertTrue(mf.with_df.mol is mol1) + def test_sgx_scf(self): + mol = gto.Mole() + mol.build( + atom = [["O" , (0. , 0. , 0.)], + [1 , (0. , -0.757 , 0.587)], + [1 , (0. , 0.757 , 0.587)] ], + basis = 'ccpvdz', + ) + method = sgx.sgx_fit(scf.RHF(mol), 'weigend') + method.with_df.dfj = True + energy = method.scf() + self.assertAlmostEqual(energy, -76.02686422219752, 9) if __name__ == "__main__": print("Full Tests for SGX") unittest.main() - diff --git a/pyscf/sgx/test/test_sgx_jk.py b/pyscf/sgx/test/test_sgx_jk.py index 42000b7a4c..0d38cd3f38 100644 --- a/pyscf/sgx/test/test_sgx_jk.py +++ b/pyscf/sgx/test/test_sgx_jk.py @@ -41,6 +41,7 @@ def test_sgx_jk(self): #dm = dm + dm.T mf = scf.UHF(mol) dm = mf.get_init_guess() + vjref, vkref = scf.hf.get_jk(mol, dm) sgxobj = sgx.SGX(mol) sgxobj.grids = sgx_jk.get_gridss(mol, 0, 1e-10) @@ -53,6 +54,8 @@ def test_sgx_jk(self): vj1, vk1 = sgx_jk.get_jk_favork(sgxobj, dm) self.assertAlmostEqual(abs(vj1-vj).max(), 0, 9) self.assertAlmostEqual(abs(vk1-vk).max(), 0, 9) + self.assertAlmostEqual(abs(vjref-vj).max(), 0, 2) + self.assertAlmostEqual(abs(vkref-vk).max(), 0, 2) with lib.temporary_env(sgxobj, debug=False): vj, vk = sgx_jk.get_jk_favorj(sgxobj, dm) @@ -62,6 +65,8 @@ def test_sgx_jk(self): vj1, vk1 = sgx_jk.get_jk_favorj(sgxobj, dm) self.assertAlmostEqual(abs(vj1-vj).max(), 0, 9) self.assertAlmostEqual(abs(vk1-vk).max(), 0, 9) + self.assertAlmostEqual(abs(vjref-vj).max(), 0, 2) + self.assertAlmostEqual(abs(vkref-vk).max(), 0, 2) def test_dfj(self): mol = gto.Mole() @@ -136,4 +141,3 @@ def test_pjs(self): if __name__ == "__main__": print("Full Tests for sgx_jk") unittest.main() - diff --git a/pyscf/x2c/x2c.py b/pyscf/x2c/x2c.py index 9d55d2a0b6..e53a59b2dd 100644 --- a/pyscf/x2c/x2c.py +++ b/pyscf/x2c/x2c.py @@ -521,13 +521,10 @@ def get_occ(self, mo_energy=None, mo_coeff=None): def init_direct_scf(self, mol=None): if mol is None: mol = self.mol - def set_vkscreen(opt, name): - opt._this.contents.r_vkscreen = _vhf._fpointer(name) - opt = _vhf.VHFOpt(mol, 'int2e_spinor', 'CVHFrkbllll_prescreen', - 'CVHFrkbllll_direct_scf', - 'CVHFrkbllll_direct_scf_dm') - opt.direct_scf_tol = self.direct_scf_tol - set_vkscreen(opt, 'CVHFrkbllll_vkscreen') + opt = dhf._VHFOpt(mol, 'int2e_spinor', 'CVHFrkbllll_prescreen', + 'CVHFrkb_q_cond', 'CVHFrkb_dm_cond', + direct_scf_tol=self.direct_scf_tol) + opt._this.r_vkscreen = _vhf._fpointer('CVHFrkbllll_vkscreen') return opt def get_jk(self, mol=None, dm=None, hermi=1, with_j=True, with_k=True,