Slightly optimize fci efficiency
sunqm committed Feb 6, 2024
1 parent 6f8ce43 commit a73a24a
Showing 2 changed files with 107 additions and 43 deletions.
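The change is the same in both files: previously each beta-string block was split across all threads with `#pragma omp for`, and the per-thread `ci1buf` results were merged afterwards (through `_reduce` over the shared `ci1bufs` array in fci_contract.c, under `#pragma omp critical` in fci_contract_nosym.c). Now each thread owns one beta-string block per round and walks the alpha strings `strk` starting from a thread-staggered offset. Every thread publishes the string it is currently processing in the shared `strk_in_process` array and spin-waits while its successor thread is on the same string, so adjacent threads avoid touching the same rows of ci1 at once, and each thread can accumulate its own `ci1buf` into `ci1` without a reduction or a critical section.

A minimal standalone sketch of the staggered spin-wait pattern follows. It mirrors the commit's logic, but the names (NROWS, row_sum, row_in_process) are illustrative only, and, like the original, it assumes far more rows than threads (na >> n_threads in the FCI code).

#include <stdio.h>
#include <omp.h>

#define NROWS 64                /* stands in for na; must be >> thread count */
#define MAX_THREADS 256

static double row_sum[NROWS];           /* shared data, updated once per (thread, row) */
static int row_in_process[MAX_THREADS]; /* which row each thread is working on */

int main(void)
{
#pragma omp parallel
{
        int n_threads = omp_get_num_threads();
        int thread_id = omp_get_thread_num();
        int next_thread = (thread_id + 1) % n_threads;
        /* Each thread watches only its successor. Threads start NROWS/T
         * rows apart, and a thread stalls instead of passing its successor,
         * so two threads do not update the same row at the same time. */
        volatile int *unsafe = row_in_process + next_thread;
        int offset = NROWS / n_threads * thread_id;
        int k, row;
        if (n_threads == 1) {           /* no successor: park the watch slot */
                unsafe = row_in_process + 1;
                *unsafe = -1;
        }
        row_in_process[thread_id] = offset;
        for (k = 0; k < NROWS; k++) {
                row = (k + offset) % NROWS;
                row_in_process[thread_id] = row;   /* publish current row */
                while (row == *unsafe) {
                        /* spin until the successor leaves this row */
                }
                row_sum[row] += 1.0;               /* the guarded shared update */
        }
        row_in_process[thread_id] = -1;            /* release the spin lock */
}
        /* each thread added 1.0 to every row exactly once */
        printf("row_sum[0] = %g\n", row_sum[0]);
        return 0;
}

As in the commit itself, the waiting is plain volatile busy-polling rather than atomics or OpenMP locks; the point is to keep synchronization out of the hot loop at the cost of spinning while a neighbor finishes.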
pyscf/lib/mcscf/fci_contract.c (55 additions, 20 deletions)
@@ -402,6 +402,15 @@ void FCIcontract_2e_spin0(double *eri, double *ci0, double *ci1,
         free(clink);
 }
 
+static void axpy2d(double *out, double *in, int count, int no, int ni)
+{
+        size_t i, j;
+        for (i = 0; i < count; i++) {
+                for (j = 0; j < ni; j++) {
+                        out[i*no+j] += in[i*ni+j];
+                }
+        }
+}
 
 void FCIcontract_2e_spin1(double *eri, double *ci0, double *ci1,
                           int norb, int na, int nb, int nlinka, int nlinkb,
@@ -413,36 +422,62 @@ void FCIcontract_2e_spin1(double *eri, double *ci0, double *ci1,
         FCIcompress_link_tril(clinkb, link_indexb, nb, nlinkb);
 
         NPdset0(ci1, ((size_t)na) * nb);
-        double *ci1bufs[MAX_THREADS];
+        int strk_in_process[MAX_THREADS];
 #pragma omp parallel
 {
-        int strk, ib;
-        size_t blen;
+        int strk, ib, ib0, ka;
         int nnorb = norb*(norb+1)/2;
         double *t1buf = malloc(sizeof(double) * (STRB_BLKSIZE*nnorb*2+2));
         double *tmp;
         double *t1 = t1buf;
         double *vt1 = t1buf + nnorb*STRB_BLKSIZE;
         double *ci1buf = malloc(sizeof(double) * (na*STRB_BLKSIZE+2));
-        ci1bufs[omp_get_thread_num()] = ci1buf;
-        for (ib = 0; ib < nb; ib += STRB_BLKSIZE) {
-                blen = MIN(STRB_BLKSIZE, nb-ib);
-                NPdset0(ci1buf, ((size_t)na) * blen);
-#pragma omp for schedule(static)
-                for (strk = 0; strk < na; strk++) {
-                        ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1, vt1,
-                                       blen, blen, blen, strk, ib,
-                                       norb, na, nb, nlinka, nlinkb,
-                                       clinka, clinkb);
-                        // swap buffer for better cache utilization in next task
-                        tmp = t1;
-                        t1 = vt1;
-                        vt1 = tmp;
+        int n_threads = omp_get_num_threads();
+        int thread_id = omp_get_thread_num();
+        int chunk_size = STRB_BLKSIZE * n_threads;
+        int blen = STRB_BLKSIZE;
+        int next_thread = (thread_id + 1) % n_threads;
+        volatile int *unsafe = strk_in_process + next_thread;
+        int strk_offset = na/n_threads * thread_id;
+        if (n_threads == 1) {
+                unsafe = strk_in_process + 1;
+                *unsafe = -1;
+        }
+
+        for (ib0 = 0; ib0 < nb; ib0 += chunk_size) {
+                ib = ib0 + thread_id * STRB_BLKSIZE;
+                if (ib < nb) {
+                        if (ib + STRB_BLKSIZE >= nb) { // the last block
+                                blen = nb - ib;
+                                if (thread_id > 0) {
+                                        unsafe = strk_in_process;
+                                }
+                        }
+
+                        NPdset0(ci1buf, ((size_t)na) * blen);
+                        strk_in_process[thread_id] = strk_offset;
+                        for (ka = 0; ka < na; ka++) {
+                                strk = (ka + strk_offset) % na;
+                                strk_in_process[thread_id] = strk;
+                                while (strk == *unsafe) {
+                                        // wait until safe to process the task
+                                }
+                                ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1, vt1,
+                                               blen, blen, blen, strk, ib,
+                                               norb, na, nb, nlinka, nlinkb,
+                                               clinka, clinkb);
+                                // swap buffer for better cache utilization in next task
+                                tmp = t1;
+                                t1 = vt1;
+                                vt1 = tmp;
+                        }
+                        // reset the status for each thread to release the spin lock
+                        strk_in_process[thread_id] = -1;
                 }
+#pragma omp barrier
-                _reduce(ci1+ib, ci1bufs, na, nb, blen);
 // An explicit barrier to ensure ci1 is updated. Without barrier, there may
 // occur race condition between FCIaxpy2d and ctr_rhf2e_kern
+                if (ib < nb) {
+                        axpy2d(ci1+ib, ci1buf, na, nb, blen);
+                }
 #pragma omp barrier
         }
         free(ci1buf);
pyscf/lib/mcscf/fci_contract_nosym.c (52 additions, 23 deletions)
@@ -177,7 +177,7 @@ static void ctr_rhf2e_kern(double *eri, double *ci0, double *ci1,
 
 static void axpy2d(double *out, double *in, int count, int no, int ni)
 {
-        int i, j;
+        size_t i, j;
         for (i = 0; i < count; i++) {
                 for (j = 0; j < ni; j++) {
                         out[i*no+j] += in[i*ni+j];
@@ -195,33 +195,62 @@ void FCIcontract_2es1(double *eri, double *ci0, double *ci1,
         FCIcompress_link(clinkb, link_indexb, norb, nb, nlinkb);
 
         NPdset0(ci1, ((size_t)na) * nb);
-
-#pragma omp parallel default(none) \
-        shared(eri, ci0, ci1, norb, na, nb, nlinka, nlinkb, \
-               clinka, clinkb)
+        int strk_in_process[MAX_THREADS];
+#pragma omp parallel
 {
-        int strk, ib, blen;
-        double *t1buf = malloc(sizeof(double) * (STRB_BLKSIZE*norb*norb*2+2));
+        int strk, ib, ib0, ka;
+        int nnorb = norb * norb;
+        double *t1buf = malloc(sizeof(double) * (STRB_BLKSIZE*nnorb*2+2));
         double *tmp;
         double *t1 = t1buf;
-        double *vt1 = t1buf + norb*norb*STRB_BLKSIZE;
+        double *vt1 = t1buf + nnorb*STRB_BLKSIZE;
         double *ci1buf = malloc(sizeof(double) * (na*STRB_BLKSIZE+2));
-        for (ib = 0; ib < nb; ib += STRB_BLKSIZE) {
-                blen = MIN(STRB_BLKSIZE, nb-ib);
-                NPdset0(ci1buf, ((size_t)na) * blen);
-#pragma omp for schedule(static)
-                for (strk = 0; strk < na; strk++) {
-                        ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1, vt1,
-                                       blen, blen, blen, strk, ib,
-                                       norb, na, nb, nlinka, nlinkb,
-                                       clinka, clinkb);
-                        // swap buffer for better cache utilization in next task
-                        tmp = t1;
-                        t1 = vt1;
-                        vt1 = tmp;
+        int n_threads = omp_get_num_threads();
+        int thread_id = omp_get_thread_num();
+        int chunk_size = STRB_BLKSIZE * n_threads;
+        int blen = STRB_BLKSIZE;
+        int next_thread = (thread_id + 1) % n_threads;
+        volatile int *unsafe = strk_in_process + next_thread;
+        int strk_offset = na/n_threads * thread_id;
+        if (n_threads == 1) {
+                unsafe = strk_in_process + 1;
+                *unsafe = -1;
+        }
+
+        for (ib0 = 0; ib0 < nb; ib0 += chunk_size) {
+                ib = ib0 + thread_id * STRB_BLKSIZE;
+                if (ib < nb) {
+                        if (ib + STRB_BLKSIZE >= nb) { // the last block
+                                blen = nb - ib;
+                                if (thread_id > 0) {
+                                        unsafe = strk_in_process;
+                                }
+                        }
+
+                        NPdset0(ci1buf, ((size_t)na) * blen);
+                        strk_in_process[thread_id] = strk_offset;
+                        for (ka = 0; ka < na; ka++) {
+                                strk = (ka + strk_offset) % na;
+                                strk_in_process[thread_id] = strk;
+                                while (strk == *unsafe) {
+                                        // wait until safe to process the task
+                                }
+                                ctr_rhf2e_kern(eri, ci0, ci1, ci1buf, t1, vt1,
+                                               blen, blen, blen, strk, ib,
+                                               norb, na, nb, nlinka, nlinkb,
+                                               clinka, clinkb);
+                                // swap buffer for better cache utilization in next task
+                                tmp = t1;
+                                t1 = vt1;
+                                vt1 = tmp;
+                        }
+                        // reset the status for each thread to release the spin lock
+                        strk_in_process[thread_id] = -1;
                 }
+#pragma omp barrier
+                if (ib < nb) {
+                        axpy2d(ci1+ib, ci1buf, na, nb, blen);
+                }
-#pragma omp critical
-                axpy2d(ci1+ib, ci1buf, na, nb, blen);
 #pragma omp barrier
         }
         free(ci1buf);
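A second fix rides along in this file's first hunk: the loop indices of `axpy2d` change from `int` to `size_t`, so the element offset `i*no+j` is computed in 64-bit rather than 32-bit arithmetic. A small illustration of the overflow this avoids; the sizes are hypothetical, not taken from the commit:

#include <stdio.h>
#include <limits.h>
#include <stddef.h>

int main(void)
{
        int count = 50000, no = 50000;  /* a large hypothetical CI vector */
        int i = count - 1, j = 0;
        /* With int indices, i*no+j is a 32-bit multiply: the true offset
         * 49999 * 50000 = 2499950000 exceeds INT_MAX, so it would overflow
         * before any widening takes place. */
        printf("true offset %lld vs INT_MAX %d\n", (long long)i * no + j, INT_MAX);
        /* With size_t indices, as in the patched axpy2d, the whole
         * expression is evaluated in size_t and stays correct. */
        size_t off = (size_t)i * no + j;
        printf("size_t offset %zu\n", off);
        return 0;
}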
