Skip to content

Commit

Permalink
Fci str2addr (pyscf#1971)
Browse files Browse the repository at this point in the history
* Update cistring module for issue pyscf#1886

* Slightly optimize the fci contraction performance
  • Loading branch information
sunqm authored Nov 24, 2023
1 parent b27e4f9 commit ad72f25
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 58 deletions.
48 changes: 39 additions & 9 deletions pyscf/fci/cistring.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,7 @@ def _occslst2strs(occslst):
class OIndexList(numpy.ndarray):
pass

def num_strings(n, m):
if m < 0 or m > n:
return 0
else:
return math.factorial(n) // (math.factorial(n-m)*math.factorial(m))
num_strings = lib.comb

def gen_linkstr_index_o1(orb_list, nelec, strs=None, tril=False):
'''Look up table, for the strings relationship in terms of a
Expand Down Expand Up @@ -345,12 +341,30 @@ def count_bit1(n):

def addr2str(norb, nelec, addr):
'''Convert CI determinant address to string'''
return addrs2str(norb, nelec, [addr])[0]
if norb >= 64:
raise NotImplementedError('norb >= 64')
assert num_strings(norb, nelec) > addr

if addr < 2**31:
return addrs2str(norb, nelec, [addr])[0]

return _addr2str(norb, nelec, addr)

def _addr2str(norb, nelec, addr):
if addr == 0 or nelec == norb or nelec == 0:
return (1 << nelec) - 1 # ..0011..11

for i in reversed(range(norb)):
addrcum = num_strings(i, nelec)
if addrcum <= addr:
return (1 << i) | _addr2str(i, nelec-1, addr-addrcum)

def addrs2str(norb, nelec, addrs):
'''Convert a list of CI determinant address to string'''
if norb >= 64:
raise NotImplementedError('norb >= 64')
if num_strings(norb, nelec) >= 2**31:
raise NotImplementedError('Large address')

addrs = numpy.asarray(addrs, dtype=numpy.int32)
assert (all(num_strings(norb, nelec) > addrs))
Expand All @@ -372,13 +386,29 @@ def str2addr(norb, nelec, string):
string = int(string, 2)
else:
assert (bin(string).count('1') == nelec)
libfci.FCIstr2addr.restype = ctypes.c_int
return libfci.FCIstr2addr(ctypes.c_int(norb), ctypes.c_int(nelec),
ctypes.c_ulonglong(string))

if num_strings(norb, nelec) < 2**31:
libfci.FCIstr2addr.restype = ctypes.c_int
return libfci.FCIstr2addr(ctypes.c_int(norb), ctypes.c_int(nelec),
ctypes.c_ulonglong(string))
return _str2addr(norb, nelec, string)

def _str2addr(norb, nelec, string):
if norb <= nelec or nelec == 0:
return 0
addr = 0
for orbital_id in reversed(range(norb)):
if (1 << orbital_id) & string:
addr += num_strings(orbital_id, nelec)
nelec -= 1
return addr

def strs2addr(norb, nelec, strings):
'''Convert a list of string to CI determinant address'''
if norb >= 64:
raise NotImplementedError('norb >= 64')
if num_strings(norb, nelec) >= 2**31:
raise NotImplementedError('Large address')

strings = numpy.asarray(strings, dtype=numpy.int64)
count = strings.size
Expand Down
38 changes: 23 additions & 15 deletions pyscf/lib/mcscf/fci_contract.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
#include "vhf/fblas.h"
#include "np_helper/np_helper.h"
#include "fci.h"
// for (16e,16o) ~ 11 MB buffer = 120 * 12870 * 8
#define STRB_BLKSIZE 112
// optimized for 512 KB L2 cache, (16e,16o)
#define STRB_BLKSIZE 160

/*
* CPU timing of single thread can be estimated:
Expand Down Expand Up @@ -291,7 +291,7 @@ static void spread_bufa_t1(double *ci1, double *t1, int nrow_t1,
* bcount_for_spread_a is different for spin1 and spin0
*/
static void ctr_rhf2e_kern(double *eri, double *ci0, double *ci1,
double *ci1buf, double *t1buf,
double *ci1buf, double *t1, double *vt1,
int bcount_for_spread_a, int ncol_ci1buf,
int bcount, int stra_id, int strb_id,
int norb, int na, int nb, int nlinka, int nlinkb,
Expand All @@ -301,8 +301,6 @@ static void ctr_rhf2e_kern(double *eri, double *ci0, double *ci1,
const double D0 = 0;
const double D1 = 1;
const int nnorb = norb * (norb+1)/2;
double *t1 = t1buf;
double *vt1 = t1buf + nnorb*bcount;

NPdset0(t1, nnorb*bcount);
FCIprog_a_t1(ci0, t1, bcount, stra_id, strb_id,
Expand Down Expand Up @@ -371,7 +369,11 @@ void FCIcontract_2e_spin0(double *eri, double *ci0, double *ci1,
{
int strk, ib;
size_t blen;
double *t1buf = malloc(sizeof(double) * (STRB_BLKSIZE*norb*(norb+1)+2));
int nnorb = norb*(norb+1)/2;
double *t1buf = malloc(sizeof(double) * (STRB_BLKSIZE*nnorb*2+2));
double *tmp;
double *t1 = t1buf;
double *vt1 = t1buf + nnorb*STRB_BLKSIZE;
double *ci1buf = malloc(sizeof(double) * (na*STRB_BLKSIZE+2));
ci1bufs[omp_get_thread_num()] = ci1buf;
for (ib = 0; ib < na; ib += STRB_BLKSIZE) {
Expand All @@ -380,15 +382,16 @@ void FCIcontract_2e_spin0(double *eri, double *ci0, double *ci1,
#pragma omp for schedule(static, 112)
/* strk starts from MAX(strk0, ib), because [0:ib,0:ib] have been evaluated */
for (strk = ib; strk < na; strk++) {
ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1buf,
ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1, vt1,
MIN(STRB_BLKSIZE, strk-ib), blen,
MIN(STRB_BLKSIZE, strk+1-ib),
strk, ib, norb, na, na, nlink, nlink,
clink, clink);
// swap buffer for better cache utilization in next task
tmp = t1;
t1 = vt1;
vt1 = tmp;
}
// NPomp_dsum_reduce_inplace(ci1bufs, blen*na);
//#pragma omp master
// FCIaxpy2d(ci1+ib, ci1buf, na, na, blen);
#pragma omp barrier
_reduce(ci1+ib, ci1bufs, na, na, blen);
// An explicit barrier to ensure ci1 is updated. Without barrier, there may
Expand Down Expand Up @@ -417,22 +420,27 @@ void FCIcontract_2e_spin1(double *eri, double *ci0, double *ci1,
{
int strk, ib;
size_t blen;
double *t1buf = malloc(sizeof(double) * (STRB_BLKSIZE*norb*(norb+1)+2));
int nnorb = norb*(norb+1)/2;
double *t1buf = malloc(sizeof(double) * (STRB_BLKSIZE*nnorb*2+2));
double *tmp;
double *t1 = t1buf;
double *vt1 = t1buf + nnorb*STRB_BLKSIZE;
double *ci1buf = malloc(sizeof(double) * (na*STRB_BLKSIZE+2));
ci1bufs[omp_get_thread_num()] = ci1buf;
for (ib = 0; ib < nb; ib += STRB_BLKSIZE) {
blen = MIN(STRB_BLKSIZE, nb-ib);
NPdset0(ci1buf, ((size_t)na) * blen);
#pragma omp for schedule(static)
for (strk = 0; strk < na; strk++) {
ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1buf,
ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1, vt1,
blen, blen, blen, strk, ib,
norb, na, nb, nlinka, nlinkb,
clinka, clinkb);
// swap buffer for better cache utilization in next task
tmp = t1;
t1 = vt1;
vt1 = tmp;
}
// NPomp_dsum_reduce_inplace(ci1bufs, blen*na);
//#pragma omp master
// FCIaxpy2d(ci1+ib, ci1buf, na, nb, blen);
#pragma omp barrier
_reduce(ci1+ib, ci1bufs, na, nb, blen);
// An explicit barrier to ensure ci1 is updated. Without barrier, there may
Expand Down
16 changes: 11 additions & 5 deletions pyscf/lib/mcscf/fci_contract_nosym.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
#include "np_helper/np_helper.h"
#include "fci.h"
#define CSUMTHR 1e-28
#define STRB_BLKSIZE 112
// optimized for 1 MB L2 cache, (16e,16o)
#define STRB_BLKSIZE 120

double FCI_t1ci_sf(double *ci0, double *t1, int bcount,
int stra_id, int strb_id,
Expand Down Expand Up @@ -147,7 +148,7 @@ static void spread_b_t1(double *ci1, double *t1,
}

static void ctr_rhf2e_kern(double *eri, double *ci0, double *ci1,
double *ci1buf, double *t1buf,
double *ci1buf, double *t1, double *vt1,
int bcount_for_spread_a, int ncol_ci1buf,
int bcount, int stra_id, int strb_id,
int norb, int na, int nb, int nlinka, int nlinkb,
Expand All @@ -157,8 +158,6 @@ static void ctr_rhf2e_kern(double *eri, double *ci0, double *ci1,
const double D0 = 0;
const double D1 = 1;
const int nnorb = norb * norb;
double *t1 = t1buf;
double *vt1 = t1buf + nnorb*bcount;
double csum;

csum = FCI_t1ci_sf(ci0, t1, bcount, stra_id, strb_id,
Expand Down Expand Up @@ -203,16 +202,23 @@ void FCIcontract_2es1(double *eri, double *ci0, double *ci1,
{
int strk, ib, blen;
double *t1buf = malloc(sizeof(double) * (STRB_BLKSIZE*norb*norb*2+2));
double *tmp;
double *t1 = t1buf;
double *vt1 = t1buf + norb*norb*STRB_BLKSIZE;
double *ci1buf = malloc(sizeof(double) * (na*STRB_BLKSIZE+2));
for (ib = 0; ib < nb; ib += STRB_BLKSIZE) {
blen = MIN(STRB_BLKSIZE, nb-ib);
NPdset0(ci1buf, ((size_t)na) * blen);
#pragma omp for schedule(static)
for (strk = 0; strk < na; strk++) {
ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1buf,
ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1, vt1,
blen, blen, blen, strk, ib,
norb, na, nb, nlinka, nlinkb,
clinka, clinkb);
// swap buffer for better cache utilization in next task
tmp = t1;
t1 = vt1;
vt1 = tmp;
}
#pragma omp critical
axpy2d(ci1+ib, ci1buf, na, nb, blen);
Expand Down
62 changes: 33 additions & 29 deletions pyscf/lib/mcscf/fci_string.c
Original file line number Diff line number Diff line change
Expand Up @@ -97,41 +97,45 @@ int FCIdes_sign(int p, uint64_t string0)
}
}

// [math.comb(n, m) for n in range(1, 21) for m in range(n)]
static int _binomial_cache[] = {
1,
1, 2,
1, 3, 3,
1, 4, 6, 4,
1, 5, 10, 10, 5,
1, 6, 15, 20, 15, 6,
1, 7, 21, 35, 35, 21, 7,
1, 8, 28, 56, 70, 56, 28, 8,
1, 9, 36, 84, 126, 126, 84, 36, 9,
1, 10, 45, 120, 210, 252, 210, 120, 45, 10,
1, 11, 55, 165, 330, 462, 462, 330, 165, 55, 11,
1, 12, 66, 220, 495, 792, 924, 792, 495, 220, 66, 12,
1, 13, 78, 286, 715, 1287, 1716, 1716, 1287, 715, 286, 78, 13,
1, 14, 91, 364, 1001, 2002, 3003, 3432, 3003, 2002, 1001, 364, 91, 14,
1, 15, 105, 455, 1365, 3003, 5005, 6435, 6435, 5005, 3003, 1365, 455, 105, 15,
1, 16, 120, 560, 1820, 4368, 8008, 11440, 12870, 11440, 8008, 4368, 1820, 560, 120, 16,
1, 17, 136, 680, 2380, 6188, 12376, 19448, 24310, 24310, 19448, 12376, 6188, 2380, 680, 136, 17,
1, 18, 153, 816, 3060, 8568, 18564, 31824, 43758, 48620, 43758, 31824, 18564, 8568, 3060, 816, 153, 18,
1, 19, 171, 969, 3876, 11628, 27132, 50388, 75582, 92378, 92378, 75582, 50388, 27132, 11628, 3876, 969, 171, 19,
1, 20, 190, 1140, 4845, 15504, 38760, 77520, 125970, 167960, 184756, 167960, 125970, 77520, 38760, 15504, 4845, 1140, 190, 20,
};
static int binomial(int n, int m)
{
int i;
if (m >= n) {
return 1;
} else if (n < 28) {
uint64_t num = 1;
uint64_t div = 1;
if (m+m >= n) {
for (i = 0; i < n-m; i++) {
num *= m+i+1;
div *= i+1;
}
} else {
for (i = 0; i < m; i++) {
num *= (n-m)+i+1;
div *= i+1;
}
}
return num / div;
} else if (n <= 20) {
return _binomial_cache[n*(n-1)/2+m];
} else {
double dnum = 1;
double ddiv = 1;
if (m+m >= n) {
for (i = 0; i < n-m; i++) {
dnum *= m+i+1;
ddiv *= i+1;
}
} else {
for (i = 0; i < m; i++) {
dnum *= (n-m)+i+1;
ddiv *= i+1;
}
if (m*2 <= n) {
m = n - m;
}
uint64_t i;
uint64_t val = 1;
for (i = m; i <= n; i++) {
val *= i;
val /= i - m;
}
return (int)(dnum / ddiv);
}
}

Expand Down
2 changes: 2 additions & 0 deletions pyscf/lib/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ def prange_split(n_total, n_sections):
else:
import math
def comb(n, k):
if k < 0 or k > n:
return 0
return math.factorial(n) // math.factorial(n-k) // math.factorial(k)

def map_with_prefetch(func, *iterables):
Expand Down

0 comments on commit ad72f25

Please sign in to comment.