Skip to content

Commit

Permalink
Fix transform_ci for more than 64 orbitals (pyscf#2095)
Browse files Browse the repository at this point in the history
* Fix transform_ci for more than 64 orbitals

* Separate occ_masks into a function
  • Loading branch information
vyu16 authored Feb 28, 2024
1 parent fb49e40 commit e2cc8c1
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions pyscf/fci/addons.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,22 +658,15 @@ def transform_ci(ci, nelec, u):
nb_new = cistring.num_strings(norb_new, nelecb)
ci = ci.reshape(na_old, nb_old)

one_particle_strs_old = numpy.asarray([1 << i for i in range(norb_old)])
one_particle_strs_new = numpy.asarray([1 << i for i in range(norb_new)])

if neleca == 0:
trans_ci_a = numpy.ones((1, 1))
else:
trans_ci_a = numpy.zeros((na_old, na_new), dtype=ua.dtype)
strs_old = numpy.asarray(cistring.make_strings(range(norb_old), neleca))

# Unitary transformation array trans_ci is the overlap between two sets of CI basis.
occ_masks_old = (strs_old[:,None] & one_particle_strs_old) != 0
occ_masks_old = _init_occ_masks(norb_old, neleca, na_old)
if norb_old == norb_new:
occ_masks_new = occ_masks_old
else:
strs_new = numpy.asarray(cistring.make_strings(range(norb_new), neleca))
occ_masks_new = (strs_new[:,None] & one_particle_strs_new) != 0
occ_masks_new = _init_occ_masks(norb_new, neleca, na_new)

# Perform
#for i in range(na_old): # old basis
Expand All @@ -692,14 +685,11 @@ def transform_ci(ci, nelec, u):
trans_ci_b = numpy.ones((1, 1))
else:
trans_ci_b = numpy.zeros((nb_old, nb_new), dtype=ub.dtype)
strs_old = numpy.asarray(cistring.make_strings(range(norb_old), nelecb))

occ_masks_old = (strs_old[:,None] & one_particle_strs_old) != 0
occ_masks_old = _init_occ_masks(norb_old, nelecb, nb_old)
if norb_old == norb_new:
occ_masks_new = occ_masks_old
else:
strs_new = numpy.asarray(cistring.make_strings(range(norb_new), nelecb))
occ_masks_new = (strs_new[:,None] & one_particle_strs_new) != 0
occ_masks_new = _init_occ_masks(norb_new, nelecb, nb_new)

occ_idx_all_strs = numpy.where(occ_masks_new)[1].reshape(nb_new,nelecb)
for i in range(nb_old):
Expand All @@ -725,4 +715,17 @@ def _unpack_nelec(nelec, spin=None):
nelec = neleca, nelecb
return nelec

def _init_occ_masks(norb, nelec, nci):
one_particle_strs = numpy.asarray(cistring.make_strings(range(norb), 1))
strs = numpy.asarray(cistring.make_strings(range(norb), nelec))
if norb < 64:
occ_masks = (strs[:,None] & one_particle_strs) != 0
else:
occ_masks = numpy.zeros((nci, norb), dtype=bool)
for i in range(nci):
for j in range(norb):
if one_particle_strs[j][0] in strs[i]:
occ_masks[i,j] = True
return occ_masks

del (LARGE_CI_TOL, RETURN_STRS, PENALTY)

0 comments on commit e2cc8c1

Please sign in to comment.