diff --git a/pyscf/fci/addons.py b/pyscf/fci/addons.py index 95fda3f6ce..3f794f1a63 100644 --- a/pyscf/fci/addons.py +++ b/pyscf/fci/addons.py @@ -658,22 +658,15 @@ def transform_ci(ci, nelec, u): nb_new = cistring.num_strings(norb_new, nelecb) ci = ci.reshape(na_old, nb_old) - one_particle_strs_old = numpy.asarray([1 << i for i in range(norb_old)]) - one_particle_strs_new = numpy.asarray([1 << i for i in range(norb_new)]) - if neleca == 0: trans_ci_a = numpy.ones((1, 1)) else: trans_ci_a = numpy.zeros((na_old, na_new), dtype=ua.dtype) - strs_old = numpy.asarray(cistring.make_strings(range(norb_old), neleca)) - - # Unitary transformation array trans_ci is the overlap between two sets of CI basis. - occ_masks_old = (strs_old[:,None] & one_particle_strs_old) != 0 + occ_masks_old = _init_occ_masks(norb_old, neleca, na_old) if norb_old == norb_new: occ_masks_new = occ_masks_old else: - strs_new = numpy.asarray(cistring.make_strings(range(norb_new), neleca)) - occ_masks_new = (strs_new[:,None] & one_particle_strs_new) != 0 + occ_masks_new = _init_occ_masks(norb_new, neleca, na_new) # Perform #for i in range(na_old): # old basis @@ -692,14 +685,11 @@ def transform_ci(ci, nelec, u): trans_ci_b = numpy.ones((1, 1)) else: trans_ci_b = numpy.zeros((nb_old, nb_new), dtype=ub.dtype) - strs_old = numpy.asarray(cistring.make_strings(range(norb_old), nelecb)) - - occ_masks_old = (strs_old[:,None] & one_particle_strs_old) != 0 + occ_masks_old = _init_occ_masks(norb_old, nelecb, nb_old) if norb_old == norb_new: occ_masks_new = occ_masks_old else: - strs_new = numpy.asarray(cistring.make_strings(range(norb_new), nelecb)) - occ_masks_new = (strs_new[:,None] & one_particle_strs_new) != 0 + occ_masks_new = _init_occ_masks(norb_new, nelecb, nb_new) occ_idx_all_strs = numpy.where(occ_masks_new)[1].reshape(nb_new,nelecb) for i in range(nb_old): @@ -725,4 +715,17 @@ def _unpack_nelec(nelec, spin=None): nelec = neleca, nelecb return nelec +def _init_occ_masks(norb, nelec, nci): + one_particle_strs = numpy.asarray(cistring.make_strings(range(norb), 1)) + strs = numpy.asarray(cistring.make_strings(range(norb), nelec)) + if norb < 64: + occ_masks = (strs[:,None] & one_particle_strs) != 0 + else: + occ_masks = numpy.zeros((nci, norb), dtype=bool) + for i in range(nci): + for j in range(norb): + if one_particle_strs[j][0] in strs[i]: + occ_masks[i,j] = True + return occ_masks + del (LARGE_CI_TOL, RETURN_STRS, PENALTY)