def addr2str(norb, nelec, addr):
    '''Convert CI determinant address to string

    Args:
        norb : int
            Number of orbitals. Must be smaller than 64.
        nelec : int
            Number of occupied orbitals (set bits) in the string.
        addr : int
            Address of the determinant. Must be smaller than
            num_strings(norb, nelec).

    Returns:
        The occupation string encoded as an integer bit mask.

    Raises:
        NotImplementedError : if norb >= 64.
    '''
    if norb >= 64:
        raise NotImplementedError('norb >= 64')
    max_addr = num_strings(norb, nelec)
    assert max_addr > addr

    # addrs2str refuses every (norb, nelec) whose string space holds 2**31
    # or more determinants, regardless of how small the requested address
    # is.  The dispatch therefore has to look at the size of the string
    # space rather than at addr itself.
    if max_addr < 2**31:
        return addrs2str(norb, nelec, [addr])[0]

    return _addr2str(norb, nelec, addr)

def _addr2str(norb, nelec, addr):
    '''Pure Python version of addr2str, based on Python integers so that
    addresses of 2**31 and beyond can be handled.

    Bit i of the returned integer is set when orbital i is occupied.
    '''
    if addr == 0 or nelec == norb or nelec == 0:
        # Nothing left to decide: the lowest nelec bits are occupied
        return (1 << nelec) - 1  # ..0011..11

    # Search from the top for the highest orbital i with
    # num_strings(i, nelec) <= addr.  That orbital is occupied.  The
    # remaining nelec-1 electrons are placed in the i orbitals below it,
    # using what is left of the address.
    for i in reversed(range(norb)):
        addrcum = num_strings(i, nelec)
        if addrcum <= addr:
            return (1 << i) | _addr2str(i, nelec-1, addr-addrcum)

def _str2addr(norb, nelec, string):
    '''Pure Python version of str2addr, based on Python integers so that
    addresses of 2**31 and beyond can be handled.

    Args:
        norb : int
            Number of orbitals (bits of string that are inspected).
        nelec : int
            Number of occupied orbitals in string.
        string : int
            Occupation string as an integer bit mask.

    Returns:
        The address, accumulated as the sum of num_strings(p, k) over the
        occupied orbitals p, where k is the number of electrons not yet
        accounted for when orbital p is visited (scanning from the highest
        orbital downwards).
    '''
    if norb <= nelec or nelec == 0:
        # Such a string space has the address 0 only
        return 0
    addr = 0
    for orbital_id in reversed(range(norb)):
        if (1 << orbital_id) & string:
            addr += num_strings(orbital_id, nelec)
            nelec -= 1
    return addr
(STRB_BLKSIZE*nnorb*2+2)); + double *tmp; + double *t1 = t1buf; + double *vt1 = t1buf + nnorb*STRB_BLKSIZE; double *ci1buf = malloc(sizeof(double) * (na*STRB_BLKSIZE+2)); ci1bufs[omp_get_thread_num()] = ci1buf; for (ib = 0; ib < na; ib += STRB_BLKSIZE) { @@ -380,15 +382,16 @@ void FCIcontract_2e_spin0(double *eri, double *ci0, double *ci1, #pragma omp for schedule(static, 112) /* strk starts from MAX(strk0, ib), because [0:ib,0:ib] have been evaluated */ for (strk = ib; strk < na; strk++) { - ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1buf, + ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1, vt1, MIN(STRB_BLKSIZE, strk-ib), blen, MIN(STRB_BLKSIZE, strk+1-ib), strk, ib, norb, na, na, nlink, nlink, clink, clink); + // swap buffer for better cache utilization in next task + tmp = t1; + t1 = vt1; + vt1 = tmp; } -// NPomp_dsum_reduce_inplace(ci1bufs, blen*na); -//#pragma omp master -// FCIaxpy2d(ci1+ib, ci1buf, na, na, blen); #pragma omp barrier _reduce(ci1+ib, ci1bufs, na, na, blen); // An explicit barrier to ensure ci1 is updated. 
Without barrier, there may @@ -417,7 +420,11 @@ void FCIcontract_2e_spin1(double *eri, double *ci0, double *ci1, { int strk, ib; size_t blen; - double *t1buf = malloc(sizeof(double) * (STRB_BLKSIZE*norb*(norb+1)+2)); + int nnorb = norb*(norb+1)/2; + double *t1buf = malloc(sizeof(double) * (STRB_BLKSIZE*nnorb*2+2)); + double *tmp; + double *t1 = t1buf; + double *vt1 = t1buf + nnorb*STRB_BLKSIZE; double *ci1buf = malloc(sizeof(double) * (na*STRB_BLKSIZE+2)); ci1bufs[omp_get_thread_num()] = ci1buf; for (ib = 0; ib < nb; ib += STRB_BLKSIZE) { @@ -425,14 +432,15 @@ void FCIcontract_2e_spin1(double *eri, double *ci0, double *ci1, NPdset0(ci1buf, ((size_t)na) * blen); #pragma omp for schedule(static) for (strk = 0; strk < na; strk++) { - ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1buf, + ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1, vt1, blen, blen, blen, strk, ib, norb, na, nb, nlinka, nlinkb, clinka, clinkb); + // swap buffer for better cache utilization in next task + tmp = t1; + t1 = vt1; + vt1 = tmp; } -// NPomp_dsum_reduce_inplace(ci1bufs, blen*na); -//#pragma omp master -// FCIaxpy2d(ci1+ib, ci1buf, na, nb, blen); #pragma omp barrier _reduce(ci1+ib, ci1bufs, na, nb, blen); // An explicit barrier to ensure ci1 is updated. 
Without barrier, there may diff --git a/pyscf/lib/mcscf/fci_contract_nosym.c b/pyscf/lib/mcscf/fci_contract_nosym.c index ddf12add55..bab54ba3ba 100644 --- a/pyscf/lib/mcscf/fci_contract_nosym.c +++ b/pyscf/lib/mcscf/fci_contract_nosym.c @@ -27,7 +27,8 @@ #include "np_helper/np_helper.h" #include "fci.h" #define CSUMTHR 1e-28 -#define STRB_BLKSIZE 112 +// optimized for 1 MB L2 cache, (16e,16o) +#define STRB_BLKSIZE 120 double FCI_t1ci_sf(double *ci0, double *t1, int bcount, int stra_id, int strb_id, @@ -147,7 +148,7 @@ static void spread_b_t1(double *ci1, double *t1, } static void ctr_rhf2e_kern(double *eri, double *ci0, double *ci1, - double *ci1buf, double *t1buf, + double *ci1buf, double *t1, double *vt1, int bcount_for_spread_a, int ncol_ci1buf, int bcount, int stra_id, int strb_id, int norb, int na, int nb, int nlinka, int nlinkb, @@ -157,8 +158,6 @@ static void ctr_rhf2e_kern(double *eri, double *ci0, double *ci1, const double D0 = 0; const double D1 = 1; const int nnorb = norb * norb; - double *t1 = t1buf; - double *vt1 = t1buf + nnorb*bcount; double csum; csum = FCI_t1ci_sf(ci0, t1, bcount, stra_id, strb_id, @@ -203,16 +202,23 @@ void FCIcontract_2es1(double *eri, double *ci0, double *ci1, { int strk, ib, blen; double *t1buf = malloc(sizeof(double) * (STRB_BLKSIZE*norb*norb*2+2)); + double *tmp; + double *t1 = t1buf; + double *vt1 = t1buf + norb*norb*STRB_BLKSIZE; double *ci1buf = malloc(sizeof(double) * (na*STRB_BLKSIZE+2)); for (ib = 0; ib < nb; ib += STRB_BLKSIZE) { blen = MIN(STRB_BLKSIZE, nb-ib); NPdset0(ci1buf, ((size_t)na) * blen); #pragma omp for schedule(static) for (strk = 0; strk < na; strk++) { - ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1buf, + ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1, vt1, blen, blen, blen, strk, ib, norb, na, nb, nlinka, nlinkb, clinka, clinkb); + // swap buffer for better cache utilization in next task + tmp = t1; + t1 = vt1; + vt1 = tmp; } #pragma omp critical axpy2d(ci1+ib, ci1buf, na, nb, blen); diff --git 
// Pascal's triangle stored row by row, rows n = 1..20, columns m = 0..n-1.
// Row n starts at offset n*(n-1)/2.  Generated by
// [math.comb(n, m) for n in range(1, 21) for m in range(n)]
static const int _binomial_cache[] = {
        1,
        1, 2,
        1, 3, 3,
        1, 4, 6, 4,
        1, 5, 10, 10, 5,
        1, 6, 15, 20, 15, 6,
        1, 7, 21, 35, 35, 21, 7,
        1, 8, 28, 56, 70, 56, 28, 8,
        1, 9, 36, 84, 126, 126, 84, 36, 9,
        1, 10, 45, 120, 210, 252, 210, 120, 45, 10,
        1, 11, 55, 165, 330, 462, 462, 330, 165, 55, 11,
        1, 12, 66, 220, 495, 792, 924, 792, 495, 220, 66, 12,
        1, 13, 78, 286, 715, 1287, 1716, 1716, 1287, 715, 286, 78, 13,
        1, 14, 91, 364, 1001, 2002, 3003, 3432, 3003, 2002, 1001, 364, 91, 14,
        1, 15, 105, 455, 1365, 3003, 5005, 6435, 6435, 5005, 3003, 1365, 455, 105, 15,
        1, 16, 120, 560, 1820, 4368, 8008, 11440, 12870, 11440, 8008, 4368, 1820, 560, 120, 16,
        1, 17, 136, 680, 2380, 6188, 12376, 19448, 24310, 24310, 19448, 12376, 6188, 2380, 680, 136, 17,
        1, 18, 153, 816, 3060, 8568, 18564, 31824, 43758, 48620, 43758, 31824, 18564, 8568, 3060, 816, 153, 18,
        1, 19, 171, 969, 3876, 11628, 27132, 50388, 75582, 92378, 92378, 75582, 50388, 27132, 11628, 3876, 969, 171, 19,
        1, 20, 190, 1140, 4845, 15504, 38760, 77520, 125970, 167960, 184756, 167960, 125970, 77520, 38760, 15504, 4845, 1140, 190, 20,
};

/*
 * Binomial coefficient C(n, m).
 *
 * Returns 1 when m >= n (existing convention of this helper) and when
 * m <= 0.  The m <= 0 guard keeps the lookup table from ever being indexed
 * with a negative offset.
 * For 0 < m < n <= 20 the value is read from _binomial_cache.  Larger n are
 * evaluated with exact 64-bit integer arithmetic.
 * The result is truncated to int; it is only meaningful when C(n, m) fits
 * in an int.
 */
static int binomial(int n, int m)
{
        if (m >= n || m <= 0) {
                return 1;
        } else if (n <= 20) {
                return _binomial_cache[n*(n-1)/2+m];
        } else {
                int i;
                uint64_t val = 1;
                // C(n, m) == C(n, n-m): iterate over the smaller of the two
                if (m*2 > n) {
                        m = n - m;
                }
                // After step i, val == C(n-m+i, i).  The product
                // val * (n-m+i) is always a multiple of i, so the integer
                // division is exact at every step.
                for (i = 1; i <= m; i++) {
                        val = val * (uint64_t)(n - m + i) / (uint64_t)i;
                }
                return (int)val;
        }
}