Skip to content

Commit

Permalink
Optimize auxbasis_response in df-hessian
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Jan 4, 2024
1 parent 1e56c74 commit 178d9ba
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 37 deletions.
37 changes: 20 additions & 17 deletions pyscf/df/hessian/rhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
# (10|0)(0|10)
int2c = auxmol.intor('int2c2e', aosym='s1')
solve_j2c = _gen_metric_solver(int2c)
int2c_ip1 = auxmol.intor('int2c2e_ip1', aosym='s1')
if hessobj.auxbasis_response:
int2c_ip1 = auxmol.intor('int2c2e_ip1', aosym='s1')

rhoj0_P = 0
if hessobj.max_memory*.8e6/8 < naux*nocc*(nocc+nao):
Expand Down Expand Up @@ -154,21 +155,22 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
_load_dim0(rhok_ip1_PkI, p0, p1))
int3c_ip1 = None

get_int3c_ip2 = _int3c_wrapper(mol, auxmol, 'int3c2e_ip2', 's1')
wj_ip2 = np.empty((naux,3))
wk_ip2_Ipk = ftmp.create_dataset('wk_ip2', (nao,naux,3,nao), 'f8')
if hessobj.auxbasis_response > 1:
wk_ip2_P__ = np.empty((naux,3,nocc,nocc))
for shl0, shl1, nL in aux_ranges:
shls_slice = (0, nbas, 0, nbas, shl0, shl1)
p0, p1 = aux_loc[shl0], aux_loc[shl1]
int3c_ip2 = get_int3c_ip2(shls_slice)
wj_ip2[p0:p1] = np.einsum('yklp,lk->py', int3c_ip2, dm0)
if with_k:
wk_ip2_Ipk[:,p0:p1] = lib.einsum('yklp,il->ipyk', int3c_ip2, dm0)
if hessobj.auxbasis_response > 1:
wk_ip2_P__[p0:p1] = lib.einsum('xuvp,ui,vj->pxij', int3c_ip2, mocc_2, mocc_2)
int3c_ip2 = None
if hessobj.auxbasis_response:
get_int3c_ip2 = _int3c_wrapper(mol, auxmol, 'int3c2e_ip2', 's1')
wj_ip2 = np.empty((naux,3))
wk_ip2_Ipk = ftmp.create_dataset('wk_ip2', (nao,naux,3,nao), 'f8')
if hessobj.auxbasis_response > 1:
wk_ip2_P__ = np.empty((naux,3,nocc,nocc))
for shl0, shl1, nL in aux_ranges:
shls_slice = (0, nbas, 0, nbas, shl0, shl1)
p0, p1 = aux_loc[shl0], aux_loc[shl1]
int3c_ip2 = get_int3c_ip2(shls_slice)
wj_ip2[p0:p1] = np.einsum('yklp,lk->py', int3c_ip2, dm0)
if with_k:
wk_ip2_Ipk[:,p0:p1] = lib.einsum('yklp,il->ipyk', int3c_ip2, dm0)
if hessobj.auxbasis_response > 1:
wk_ip2_P__[p0:p1] = lib.einsum('xuvp,ui,vj->pxij', int3c_ip2, mocc_2, mocc_2)
int3c_ip2 = None

if hessobj.auxbasis_response > 1:
get_int3c_ipip2 = _int3c_wrapper(mol, auxmol, 'int3c2e_ipip2', 's1')
Expand Down Expand Up @@ -399,7 +401,8 @@ def _gen_jk(hessobj, mo_coeff, mo_occ, chkfile=None, atmlst=None,
int2c = auxmol.intor('int2c2e', aosym='s1')
solve_j2c = _gen_metric_solver(int2c)
int2c = None
int2c_ip1 = auxmol.intor('int2c2e_ip1', aosym='s1')
if hessobj.auxbasis_response:
int2c_ip1 = auxmol.intor('int2c2e_ip1', aosym='s1')
rhoj0_P = 0
if with_k:
rhok0_Pl_ = np.empty((naux,nao,nocc))
Expand Down
42 changes: 22 additions & 20 deletions pyscf/df/hessian/uhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
# (10|0)(0|10)
int2c = auxmol.intor('int2c2e', aosym='s1')
solve_j2c = _gen_metric_solver(int2c)
int2c_ip1 = auxmol.intor('int2c2e_ip1', aosym='s1')
if hessobj.auxbasis_response:
int2c_ip1 = auxmol.intor('int2c2e_ip1', aosym='s1')

rhoj0_P = 0
if with_k:
Expand Down Expand Up @@ -171,25 +172,26 @@ def _partial_hess_ejk(hessobj, mo_energy=None, mo_coeff=None, mo_occ=None,
vk2a_buf += lib.einsum('xijp,pkjy->xyki', int3c_ip1, _load_dim0(rhoka_ip1_PkI, p0, p1))
vk2b_buf += lib.einsum('xijp,pkjy->xyki', int3c_ip1, _load_dim0(rhokb_ip1_PkI, p0, p1))

get_int3c_ip2 = _int3c_wrapper(mol, auxmol, 'int3c2e_ip2', 's1')
wj_ip2 = np.empty((naux,3))
wka_ip2_Ipk = ftmp.create_dataset('wka_ip2', (nao,naux,3,nao), 'f8')
wkb_ip2_Ipk = ftmp.create_dataset('wkb_ip2', (nao,naux,3,nao), 'f8')
if hessobj.auxbasis_response > 1:
wka_ip2_P__ = np.empty((naux,3,nocca,nocca))
wkb_ip2_P__ = np.empty((naux,3,noccb,noccb))
for shl0, shl1, nL in aux_ranges:
shls_slice = (0, nbas, 0, nbas, shl0, shl1)
p0, p1 = aux_loc[shl0], aux_loc[shl1]
int3c_ip2 = get_int3c_ip2(shls_slice)
wj_ip2[p0:p1] = np.einsum('yklp,lk->py', int3c_ip2, dm0)
if with_k:
wka_ip2_Ipk[:,p0:p1] = lib.einsum('yklp,il->ipyk', int3c_ip2, dm0a)
wkb_ip2_Ipk[:,p0:p1] = lib.einsum('yklp,il->ipyk', int3c_ip2, dm0b)
if hessobj.auxbasis_response > 1:
wka_ip2_P__[p0:p1] = lib.einsum('xuvp,ui,vj->pxij', int3c_ip2, mocca, mocca)
wkb_ip2_P__[p0:p1] = lib.einsum('xuvp,ui,vj->pxij', int3c_ip2, moccb, moccb)
int3c_ip2 = None
if hessobj.auxbasis_response:
get_int3c_ip2 = _int3c_wrapper(mol, auxmol, 'int3c2e_ip2', 's1')
wj_ip2 = np.empty((naux,3))
wka_ip2_Ipk = ftmp.create_dataset('wka_ip2', (nao,naux,3,nao), 'f8')
wkb_ip2_Ipk = ftmp.create_dataset('wkb_ip2', (nao,naux,3,nao), 'f8')
if hessobj.auxbasis_response > 1:
wka_ip2_P__ = np.empty((naux,3,nocca,nocca))
wkb_ip2_P__ = np.empty((naux,3,noccb,noccb))
for shl0, shl1, nL in aux_ranges:
shls_slice = (0, nbas, 0, nbas, shl0, shl1)
p0, p1 = aux_loc[shl0], aux_loc[shl1]
int3c_ip2 = get_int3c_ip2(shls_slice)
wj_ip2[p0:p1] = np.einsum('yklp,lk->py', int3c_ip2, dm0)
if with_k:
wka_ip2_Ipk[:,p0:p1] = lib.einsum('yklp,il->ipyk', int3c_ip2, dm0a)
wkb_ip2_Ipk[:,p0:p1] = lib.einsum('yklp,il->ipyk', int3c_ip2, dm0b)
if hessobj.auxbasis_response > 1:
wka_ip2_P__[p0:p1] = lib.einsum('xuvp,ui,vj->pxij', int3c_ip2, mocca, mocca)
wkb_ip2_P__[p0:p1] = lib.einsum('xuvp,ui,vj->pxij', int3c_ip2, moccb, moccb)
int3c_ip2 = None

if hessobj.auxbasis_response > 1:
get_int3c_ipip2 = _int3c_wrapper(mol, auxmol, 'int3c2e_ipip2', 's1')
Expand Down

0 comments on commit 178d9ba

Please sign in to comment.