Skip to content

Commit

Permalink
Update cistring module for issue pyscf#1886
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Nov 20, 2023
1 parent fbd9220 commit 1dbf044
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 38 deletions.
48 changes: 39 additions & 9 deletions pyscf/fci/cistring.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,7 @@ def _occslst2strs(occslst):
class OIndexList(numpy.ndarray):
pass

def num_strings(n, m):
if m < 0 or m > n:
return 0
else:
return math.factorial(n) // (math.factorial(n-m)*math.factorial(m))
num_strings = lib.comb

def gen_linkstr_index_o1(orb_list, nelec, strs=None, tril=False):
'''Look up table, for the strings relationship in terms of a
Expand Down Expand Up @@ -345,12 +341,30 @@ def count_bit1(n):

def addr2str(norb, nelec, addr):
'''Convert CI determinant address to string'''
return addrs2str(norb, nelec, [addr])[0]
if norb >= 64:
raise NotImplementedError('norb >= 64')
assert num_strings(norb, nelec) > addr

if addr < 2**31:
return addrs2str(norb, nelec, [addr])[0]

return _addr2str(norb, nelec, addr)

def _addr2str(norb, nelec, addr):
if addr == 0 or nelec == norb or nelec == 0:
return (1 << nelec) - 1 # ..0011..11

for i in reversed(range(norb)):
addrcum = num_strings(i, nelec)
if addrcum <= addr:
return (1 << i) | _addr2str(i, nelec-1, addr-addrcum)

def addrs2str(norb, nelec, addrs):
'''Convert a list of CI determinant address to string'''
if norb >= 64:
raise NotImplementedError('norb >= 64')
if num_strings(norb, nelec) >= 2**31:
raise NotImplementedError('Large address')

addrs = numpy.asarray(addrs, dtype=numpy.int32)
assert (all(num_strings(norb, nelec) > addrs))
Expand All @@ -372,13 +386,29 @@ def str2addr(norb, nelec, string):
string = int(string, 2)
else:
assert (bin(string).count('1') == nelec)
libfci.FCIstr2addr.restype = ctypes.c_int
return libfci.FCIstr2addr(ctypes.c_int(norb), ctypes.c_int(nelec),
ctypes.c_ulonglong(string))

if num_strings(norb, nelec) < 2**31:
libfci.FCIstr2addr.restype = ctypes.c_int
return libfci.FCIstr2addr(ctypes.c_int(norb), ctypes.c_int(nelec),
ctypes.c_ulonglong(string))
return _str2addr(norb, nelec, string)

def _str2addr(norb, nelec, string):
if norb <= nelec or nelec == 0:
return 0
addr = 0
for orbital_id in reversed(range(norb)):
if (1 << orbital_id) & string:
addr += num_strings(orbital_id, nelec)
nelec -= 1
return addr

def strs2addr(norb, nelec, strings):
'''Convert a list of string to CI determinant address'''
if norb >= 64:
raise NotImplementedError('norb >= 64')
if num_strings(norb, nelec) >= 2**31:
raise NotImplementedError('Large address')

strings = numpy.asarray(strings, dtype=numpy.int64)
count = strings.size
Expand Down
62 changes: 33 additions & 29 deletions pyscf/lib/mcscf/fci_string.c
Original file line number Diff line number Diff line change
Expand Up @@ -97,41 +97,45 @@ int FCIdes_sign(int p, uint64_t string0)
}
}

// [math.comb(n, m) for n in range(1, 21) for m in range(n)]
static int _binomial_cache[] = {
1,
1, 2,
1, 3, 3,
1, 4, 6, 4,
1, 5, 10, 10, 5,
1, 6, 15, 20, 15, 6,
1, 7, 21, 35, 35, 21, 7,
1, 8, 28, 56, 70, 56, 28, 8,
1, 9, 36, 84, 126, 126, 84, 36, 9,
1, 10, 45, 120, 210, 252, 210, 120, 45, 10,
1, 11, 55, 165, 330, 462, 462, 330, 165, 55, 11,
1, 12, 66, 220, 495, 792, 924, 792, 495, 220, 66, 12,
1, 13, 78, 286, 715, 1287, 1716, 1716, 1287, 715, 286, 78, 13,
1, 14, 91, 364, 1001, 2002, 3003, 3432, 3003, 2002, 1001, 364, 91, 14,
1, 15, 105, 455, 1365, 3003, 5005, 6435, 6435, 5005, 3003, 1365, 455, 105, 15,
1, 16, 120, 560, 1820, 4368, 8008, 11440, 12870, 11440, 8008, 4368, 1820, 560, 120, 16,
1, 17, 136, 680, 2380, 6188, 12376, 19448, 24310, 24310, 19448, 12376, 6188, 2380, 680, 136, 17,
1, 18, 153, 816, 3060, 8568, 18564, 31824, 43758, 48620, 43758, 31824, 18564, 8568, 3060, 816, 153, 18,
1, 19, 171, 969, 3876, 11628, 27132, 50388, 75582, 92378, 92378, 75582, 50388, 27132, 11628, 3876, 969, 171, 19,
1, 20, 190, 1140, 4845, 15504, 38760, 77520, 125970, 167960, 184756, 167960, 125970, 77520, 38760, 15504, 4845, 1140, 190, 20,
};
static int binomial(int n, int m)
{
int i;
if (m >= n) {
return 1;
} else if (n < 28) {
uint64_t num = 1;
uint64_t div = 1;
if (m+m >= n) {
for (i = 0; i < n-m; i++) {
num *= m+i+1;
div *= i+1;
}
} else {
for (i = 0; i < m; i++) {
num *= (n-m)+i+1;
div *= i+1;
}
}
return num / div;
} else if (n <= 20) {
return _binomial_cache[n*(n-1)/2+m];
} else {
double dnum = 1;
double ddiv = 1;
if (m+m >= n) {
for (i = 0; i < n-m; i++) {
dnum *= m+i+1;
ddiv *= i+1;
}
} else {
for (i = 0; i < m; i++) {
dnum *= (n-m)+i+1;
ddiv *= i+1;
}
if (m*2 <= n) {
m = n - m;
}
uint64_t i;
uint64_t val = 1;
for (i = m; i <= n; i++) {
val *= i;
val /= i - m;
}
return (int)(dnum / ddiv);
}
}

Expand Down
2 changes: 2 additions & 0 deletions pyscf/lib/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ def prange_split(n_total, n_sections):
else:
import math
def comb(n, k):
if k < 0 or k > n:
return 0
return math.factorial(n) // math.factorial(n-k) // math.factorial(k)

def map_with_prefetch(func, *iterables):
Expand Down

0 comments on commit 1dbf044

Please sign in to comment.