Skip to content

Commit

Permalink
Fix _vhf.direct_mapdm
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Sep 29, 2023
1 parent 7a07113 commit cf99035
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
1 change: 0 additions & 1 deletion pyscf/grad/test/test_h2o.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,3 @@ def test_roks_b3lypg(self):
if __name__ == "__main__":
print("Full Tests for H2O")
unittest.main()

20 changes: 13 additions & 7 deletions pyscf/scf/_vhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,8 @@ def direct(dms, atm, bas, env, vhfopt=None, hermi=0, cart=False,
vk = vk.reshape(dms_shape)
return vj, vk

# call all fjk for each dm, the return array has len(dms)*len(jkdescript)*ncomp components
# call all fjk for each dm. The return has the shape
# [len(jkdescript),len(dms),ncomp,nao,nao]
# jkdescript: 'ij->s1kl', 'kl->s2ij', ...
def direct_mapdm(intor, aosym, jkdescript,
dms, ncomp, atm, bas, env, vhfopt=None, cintopt=None,
Expand Down Expand Up @@ -504,11 +505,13 @@ def direct_mapdm(intor, aosym, jkdescript,
if ncomp == 1:
vjk = [v[0] for v in vjk]

# vjk.reshape(n_jk,n_dm,...).transpose(1,0,...)
if n_dm > 1 and n_jk > 1:
vjk = [vjk[i::n_dm] for i in range(n_dm)]
elif n_jk == 1 and single_dm:
vjk = vjk[0]
if single_dm:
if isinstance(jkdescript, str):
vjk = vjk[0]
elif isinstance(jkdescript, str):
vjk = numpy.asarray(vjk)
else: # n_jk > 1
vjk = [numpy.asarray(vjk[i*n_dm:(i+1)*n_dm]) for i in range(n_jk)]
return vjk

# for density matrices in dms, bind each dm to a jk operator
Expand Down Expand Up @@ -598,7 +601,10 @@ def nr_direct_drv(intor, aosym, jkscript,
if out is None:
buf = numpy.empty(vshape)
else:
buf = numpy.ndarray(vshape, dtype=numpy.double, buffer=out[i])
buf = out[i]
assert buf.shape == vshape
assert buf.dtype == numpy.double
assert buf.flags.c_contiguous
vjk.append(buf)
dmsptr[i] = dms[i].ctypes.data_as(ctypes.c_void_p)
vjkptr[i] = vjk[i].ctypes.data_as(ctypes.c_void_p)
Expand Down

0 comments on commit cf99035

Please sign in to comment.