Skip to content

Commit

Permalink
Add checks for malloc
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Jan 31, 2024
1 parent 947a011 commit 3030ced
Show file tree
Hide file tree
Showing 8 changed files with 100 additions and 13 deletions.
7 changes: 6 additions & 1 deletion pyscf/lib/ao2mo/nr_ao2mo.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/

#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <assert.h>
//#define NDEBUG
Expand Down Expand Up @@ -1223,7 +1224,11 @@ void AO2MOnr_e1_drv(int (*intor)(), void (*fill)(), void (*ftrans)(), int (*fmmm
{
int nao = ao_loc[nbas];
double *eri_ao = malloc(sizeof(double) * nao*nao*nkl*ncomp);
assert(eri_ao);
if (eri_ao == NULL) {
fprintf(stderr, "malloc(%zu) falied in AO2MOnr_e1_drv\n",
sizeof(double) * nao*nao*nkl*ncomp);
exit(1);
}
AO2MOnr_e1fill_drv(intor, fill, eri_ao, klsh_start, klsh_count,
nkl, ncomp, ao_loc, cintopt, vhfopt,
atm, natm, bas, nbas, env);
Expand Down
7 changes: 6 additions & 1 deletion pyscf/lib/ao2mo/nrr_ao2mo.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/

#include <stdlib.h>
#include <stdio.h>
#include <complex.h>
#include <math.h>
#include <assert.h>
Expand Down Expand Up @@ -237,7 +238,11 @@ void AO2MOnrr_e1_drv(int (*intor)(), void (*fill)(),


double *eri_ao = malloc(sizeof(double)* nao*nao*nkl*ncomp);
assert(eri_ao);
if (eri_ao == NULL) {
fprintf(stderr, "malloc(%zu) falied in AO2MOnrr_e1_drv\n",
sizeof(double) * nao*nao*nkl*ncomp);
exit(1);
}
int ish, kl;
int (*fprescreen)();
if (vhfopt) {
Expand Down
7 changes: 6 additions & 1 deletion pyscf/lib/ao2mo/r_ao2mo.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/

#include <stdlib.h>
#include <stdio.h>
#include <complex.h>
#include <math.h>
#include <assert.h>
Expand Down Expand Up @@ -833,7 +834,11 @@ void AO2MOr_e1_drv(int (*intor)(), void (*fill)(),

double complex *eri_ao = malloc(sizeof(double complex)
* nao*nao*nkl*ncomp);
assert(eri_ao);
if (eri_ao == NULL) {
fprintf(stderr, "malloc(%zu) falied in AO2MOr_e1_drv\n",
sizeof(double complex) * nao*nao*nkl*ncomp);
exit(1);
}
int ish, kl;
int (*fprescreen)();
if (vhfopt != NULL) {
Expand Down
43 changes: 37 additions & 6 deletions pyscf/lib/cc/ccsd_t.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/

#include <stdlib.h>
#include <stdio.h>
#include <complex.h>
#include "config.h"
#include "np_helper/np_helper.h"
Expand Down Expand Up @@ -393,11 +394,16 @@ void CCsd_t_contract(double *e_tot,
_make_permute_indices(permute_idx, nocc);
#pragma omp parallel default(none) \
shared(njobs, nocc, nvir, mo_energy, t1T, t2T, nirrep, o_ir_loc, \
v_ir_loc, oo_ir_loc, orbsym, vooo, fvo, jobs, e_tot, permute_idx)
v_ir_loc, oo_ir_loc, orbsym, vooo, fvo, jobs, e_tot, permute_idx, stderr)
{
int a, b, c;
size_t k;
double *cache1 = malloc(sizeof(double) * (nocc*nocc*nocc*3+2));
if (cache1 == NULL) {
fprintf(stderr, "malloc(%zu) falied in CCsd_t_contract\n",
sizeof(double) * nocc*nocc*nocc*3);
exit(1);
}
double *t1Thalf = malloc(sizeof(double) * nvir*nocc * 2);
double *fvohalf = t1Thalf + nvir*nocc;
for (k = 0; k < nvir*nocc; k++) {
Expand Down Expand Up @@ -443,11 +449,16 @@ void QCIsd_t_contract(double *e_tot,
_make_permute_indices(permute_idx, nocc);
#pragma omp parallel default(none) \
shared(njobs, nocc, nvir, mo_energy, t1T, t2T, nirrep, o_ir_loc, \
v_ir_loc, oo_ir_loc, orbsym, vooo, fvo, jobs, e_tot, permute_idx)
v_ir_loc, oo_ir_loc, orbsym, vooo, fvo, jobs, e_tot, permute_idx, stderr)
{
int a, b, c;
size_t k;
double *cache1 = malloc(sizeof(double) * (nocc*nocc*nocc*3+2));
if (cache1 == NULL) {
fprintf(stderr, "malloc(%zu) falied in QCIsd_t_contract\n",
sizeof(double) * nocc*nocc*nocc*3);
exit(1);
}
double *t1Thalf = malloc(sizeof(double) * nvir*nocc * 2);
double *fvohalf = t1Thalf + nvir*nocc;
for (k = 0; k < nvir*nocc; k++) {
Expand Down Expand Up @@ -619,11 +630,16 @@ void CCsd_t_zcontract(double complex *e_tot,

#pragma omp parallel default(none) \
shared(njobs, nocc, nvir, mo_energy, t1T, t2T, nirrep, o_ir_loc, \
v_ir_loc, oo_ir_loc, orbsym, vooo, fvo, jobs, e_tot, permute_idx)
v_ir_loc, oo_ir_loc, orbsym, vooo, fvo, jobs, e_tot, permute_idx, stderr)
{
int a, b, c;
size_t k;
double complex *cache1 = malloc(sizeof(double complex) * (nocc*nocc*nocc*3+2));
if (cache1 == NULL) {
fprintf(stderr, "malloc(%zu) falied in CCsd_t_zcontract\n",
sizeof(double complex) * nocc*nocc*nocc*3);
exit(1);
}
double complex *t1Thalf = malloc(sizeof(double complex) * nvir*nocc * 2);
double complex *fvohalf = t1Thalf + nvir*nocc;
for (k = 0; k < nvir*nocc; k++) {
Expand Down Expand Up @@ -672,11 +688,16 @@ void QCIsd_t_zcontract(double complex *e_tot,

#pragma omp parallel default(none) \
shared(njobs, nocc, nvir, mo_energy, t1T, t2T, nirrep, o_ir_loc, \
v_ir_loc, oo_ir_loc, orbsym, vooo, fvo, jobs, e_tot, permute_idx)
v_ir_loc, oo_ir_loc, orbsym, vooo, fvo, jobs, e_tot, permute_idx, stderr)
{
int a, b, c;
size_t k;
double complex *cache1 = malloc(sizeof(double complex) * (nocc*nocc*nocc*3+2));
if (cache1 == NULL) {
fprintf(stderr, "malloc(%zu) falied in QCIsd_t_zcontract\n",
sizeof(double complex) * nocc*nocc*nocc*3);
exit(1);
}
double complex *t1Thalf = malloc(sizeof(double complex) * nvir*nocc * 2);
double complex *fvohalf = t1Thalf + nvir*nocc;
for (k = 0; k < nvir*nocc; k++) {
Expand Down Expand Up @@ -853,11 +874,16 @@ void MPICCsd_t_contract(double *e_tot, double *mo_energy, double *t1T,

#pragma omp parallel default(none) \
shared(njobs, nocc, nvir, mo_energy, t1T, fvo, jobs, e_tot, slices, \
data_ptrs, permute_idx)
data_ptrs, permute_idx, stderr)
{
int a, b, c;
size_t k;
double *cache1 = malloc(sizeof(double) * (nocc*nocc*nocc*3+2));
if (cache1 == NULL) {
fprintf(stderr, "malloc(%zu) falied in MPICCsd_t_contract\n",
sizeof(double) * nocc*nocc*nocc*3);
exit(1);
}
double *t1Thalf = malloc(sizeof(double) * nvir*nocc * 2);
double *fvohalf = t1Thalf + nvir*nocc;
for (k = 0; k < nvir*nocc; k++) {
Expand Down Expand Up @@ -1081,11 +1107,16 @@ void CCsd_zcontract_t3T(double complex *t3Tw, double complex *t3Tv, double *mo_e

#pragma omp parallel default(none) \
shared(njobs, nocc, nvir, nkpts, t3Tw, t3Tv, mo_offset, mo_energy, t1T, fvo, jobs, slices, \
data_ptrs, permute_idx)
data_ptrs, permute_idx, stderr)
{
int a, b, c;
size_t k;
complex double *cache1 = malloc(sizeof(double complex) * (nocc*nocc*nocc*3+2));
if (cache1 == NULL) {
fprintf(stderr, "malloc(%zu) falied in CCsd_zcontract_t3T\n",
sizeof(double complex) * nocc*nocc*nocc*3);
exit(1);
}
complex double *t1Thalf = malloc(sizeof(double complex) * nkpts*nvir*nocc*2);
complex double *fvohalf = t1Thalf + nkpts*nvir*nocc;
for (k = 0; k < nkpts*nvir*nocc; k++) {
Expand Down
29 changes: 25 additions & 4 deletions pyscf/lib/cc/uccsd_t.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/

#include <stdlib.h>
#include <stdio.h>
#include "config.h"
#include "np_helper/np_helper.h"
#include "vhf/fblas.h"
Expand Down Expand Up @@ -301,11 +302,16 @@ void CCuccsd_t_aaa(double complex *e_tot,
#pragma omp parallel default(none) \
shared(njobs, nocc, nvir, mo_energy, t1T, t2T, nirrep, o_ir_loc, \
v_ir_loc, oo_ir_loc, orbsym, vooo, fvohalf, jobs, e_tot, \
permute_idx)
permute_idx, stderr)
{
int a, b, c;
size_t k;
double *cache1 = malloc(sizeof(double) * (nocc*nocc*nocc*3+2));
if (cache1 == NULL) {
fprintf(stderr, "malloc(%zu) falied in CCuccsd_t_aaa\n",
sizeof(double) * nocc*nocc*nocc*3);
exit(1);
}
double e = 0;
#pragma omp for schedule (dynamic, 4)
for (k = 0; k < njobs; k++) {
Expand Down Expand Up @@ -544,12 +550,17 @@ void CCuccsd_t_baa(double complex *e_tot,
t1aT, t1bT, t2aaT, t2abT};

#pragma omp parallel default(none) \
shared(njobs, nocca, noccb, nvira, nvirb, vs_ts, jobs, e_tot)
shared(njobs, nocca, noccb, nvira, nvirb, vs_ts, jobs, e_tot, stderr)
{
int a, b, c;
size_t k;
double *cache1 = malloc(sizeof(double) * (noccb*nocca*nocca*5+1 +
nocca*2+noccb*2));
if (cache1 == NULL) {
fprintf(stderr, "malloc(%zu) falied in CCuccsd_t_baa\n",
sizeof(double) * noccb*nocca*nocca*5);
exit(1);
}
double e = 0;
#pragma omp for schedule (dynamic, 4)
for (k = 0; k < njobs; k++) {
Expand Down Expand Up @@ -697,12 +708,17 @@ void CCuccsd_t_zaaa(double complex *e_tot,
#pragma omp parallel default(none) \
shared(njobs, nocc, nvir, mo_energy, t1T, t2T, nirrep, o_ir_loc, \
v_ir_loc, oo_ir_loc, orbsym, vooo, fvohalf, jobs, e_tot, \
permute_idx)
permute_idx, stderr)
{
int a, b, c;
size_t k;
double complex *cache1 = malloc(sizeof(double complex) *
(nocc*nocc*nocc*3+2));
if (cache1 == NULL) {
fprintf(stderr, "malloc(%zu) falied in CCuccsd_t_zaaa\n",
sizeof(double complex) * nocc*nocc*nocc*3);
exit(1);
}
double complex e = 0;
#pragma omp for schedule (dynamic, 4)
for (k = 0; k < njobs; k++) {
Expand Down Expand Up @@ -900,13 +916,18 @@ void CCuccsd_t_zbaa(double complex *e_tot,
t1aT, t1bT, t2aaT, t2abT};

#pragma omp parallel default(none) \
shared(njobs, nocca, noccb, nvira, nvirb, vs_ts, jobs, e_tot)
shared(njobs, nocca, noccb, nvira, nvirb, vs_ts, jobs, e_tot, stderr)
{
int a, b, c;
size_t k;
double complex *cache1 = malloc(sizeof(double complex) *
(noccb*nocca*nocca*5+1 +
nocca*2+noccb*2));
if (cache1 == NULL) {
fprintf(stderr, "malloc(%zu) falied in CCuccsd_t_zbaa\n",
sizeof(double complex) * noccb*nocca*nocca*5);
exit(1);
}
double complex e = 0;
#pragma omp for schedule (dynamic, 4)
for (k = 0; k < njobs; k++) {
Expand Down
5 changes: 5 additions & 0 deletions pyscf/lib/gto/fill_r_4c.c
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ void GTOr4c_drv(int (*intor)(), void (*fill)(), int (*prescreen)(),
{
int ish, jsh, ij;
double *buf = malloc(sizeof(double) * cache_size);
if (buf == NULL) {
fprintf(stderr, "malloc(%zu) falied in GTOr4c_drv\n",
sizeof(double) * cache_size);
exit(1);
}
#pragma omp for schedule(dynamic)
for (ij = 0; ij < nish*njsh; ij++) {
ish = ij / njsh;
Expand Down
5 changes: 5 additions & 0 deletions pyscf/lib/vhf/nr_sgx_direct.c
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,11 @@ void SGXsetnr_direct_scf_dm(CVHFOpt *opt, double *dm, int nset, int *ao_loc,
free(opt->dm_cond);
}
opt->dm_cond = (double *)malloc(sizeof(double) * nbas*ngrids);
if (opt->dm_cond == NULL) {
fprintf(stderr, "malloc(%zu) falied in SGXsetnr_direct_scf_dm\n",
sizeof(double) * nbas*ngrids);
exit(1);
}
// nbas in the input arguments may different to opt->nbas.
// Use opt->nbas because it is used in the prescreen function
memset(opt->dm_cond, 0, sizeof(double)*nbas*ngrids);
Expand Down
10 changes: 10 additions & 0 deletions pyscf/lib/vhf/optimizer.c
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,11 @@ void CVHFsetnr_direct_scf(CVHFOpt *opt, int (*intor)(), CINTOpt *cintopt,
// Use opt->nbas because it is used in the prescreen function
nbas = opt->nbas;
opt->q_cond = (double *)malloc(sizeof(double) * nbas*nbas);
if (opt->q_cond == NULL) {
fprintf(stderr, "malloc(%zu) falied in CVHFsetnr_direct_scf\n",
sizeof(double) * nbas*nbas);
exit(1);
}
CVHFset_int2e_q_cond(intor, cintopt, opt->q_cond, ao_loc,
atm, natm, bas, nbas, env);
}
Expand Down Expand Up @@ -522,6 +527,11 @@ void CVHFsetnr_direct_scf_dm(CVHFOpt *opt, double *dm, int nset, int *ao_loc,
// Use opt->nbas because it is used in the prescreen function
nbas = opt->nbas;
opt->dm_cond = (double *)malloc(sizeof(double) * nbas*nbas);
if (opt->dm_cond == NULL) {
fprintf(stderr, "malloc(%zu) falied in CVHFsetnr_direct_scf_dm\n",
sizeof(double) * nbas*nbas);
exit(1);
}
CVHFnr_dm_cond(opt->dm_cond, dm, nset, ao_loc, atm, natm, bas, nbas, env);
}

Expand Down

0 comments on commit 3030ced

Please sign in to comment.