Skip to content

Commit

Permalink
Started TREXIO wfn IO, not fully tested
Browse files Browse the repository at this point in the history
  • Loading branch information
David Williams-Young committed Nov 6, 2023
1 parent a9a0208 commit f249eee
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 4 deletions.
2 changes: 1 addition & 1 deletion include/macis/asci/determinant_search.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ std::vector<wfn_t<N>> asci_search(
dist_quickselect(scores.begin(), scores.end(), top_k_elements, comm,
std::greater<double>{}, std::equal_to<double>{});

logger->info(" * Kth Score Pivot = {.2e}", kth_score);
logger->info(" * Kth Score Pivot = {:.2e}", kth_score);
// Partition local pairs into less / eq batches
auto [g_begin, e_begin, l_begin, _end] = leg_partition(
asci_pairs.begin(), asci_pairs.end(), kth_score,
Expand Down
4 changes: 4 additions & 0 deletions include/macis/util/trexio.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,15 @@ class TREXIOFile {
double read_nucleus_repulsion() const;
void read_mo_1e_int_core_hamiltonian(double* h) const;
void read_mo_2e_int_eri(double* V) const;
int64_t read_determinant_num() const;
int32_t get_determinant_int64_num() const;
void read_determinant_list(int64_t ndet, int64_t* dets, int64_t ioff = 0) const;

void write_mo_num(int64_t nmo);
void write_nucleus_repulsion(double E);
void write_mo_1e_int_core_hamiltonian(const double* h);
void write_mo_2e_int_eri(const double* V);
void write_determinant_list(int64_t ndet, const int64_t* dets, int64_t ioff = 0);

};

Expand Down
28 changes: 28 additions & 0 deletions src/macis/trexio.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,28 @@ void TREXIOFile::read_mo_2e_int_eri(double* V) const {
}


int64_t TREXIOFile::read_determinant_num() const {
int64_t ndet;
auto rc = trexio_read_determinant_num_64(file_handle_, &ndet);
if(rc != TREXIO_SUCCESS) TREXIO_EXCEPTION(rc);
return ndet;
}

int32_t TREXIOFile::get_determinant_int64_num() const {
int32_t n64;
auto rc = trexio_get_int64_num(file_handle_, &n64);
if(rc != TREXIO_SUCCESS) TREXIO_EXCEPTION(rc);
return n64;
}

void TREXIOFile::read_determinant_list(int64_t ndet, int64_t* dets, int64_t ioff) const {
int64_t icount = ndet;
auto rc = trexio_read_determinant_list(file_handle_, ioff, &icount, dets);
if(rc != TREXIO_SUCCESS) TREXIO_EXCEPTION(rc);
}






Expand Down Expand Up @@ -143,4 +165,10 @@ void TREXIOFile::write_mo_2e_int_eri(const double* V) {
if(rc != TREXIO_SUCCESS) TREXIO_EXCEPTION(rc);
}

void TREXIOFile::write_determinant_list(int64_t ndet, const int64_t* dets, int64_t ioff) {
int64_t icount = ndet;
auto rc = trexio_write_determinant_list(file_handle_, ioff, icount, dets);
if(rc != TREXIO_SUCCESS) TREXIO_EXCEPTION(rc);
}

}
14 changes: 11 additions & 3 deletions tests/standalone_driver.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ int main(int argc, char** argv) {
macis::ASCISettings asci_settings;
std::string asci_wfn_fname, asci_wfn_out_fname;
double asci_E0 = 0.0;
bool compute_asci_E0 = true;
bool compute_asci_E0 = true, pt2 = true;
OPT_KEYWORD("ASCI.NTDETS_MAX", asci_settings.ntdets_max, size_t);
OPT_KEYWORD("ASCI.NTDETS_MIN", asci_settings.ntdets_min, size_t);
OPT_KEYWORD("ASCI.NCDETS_MAX", asci_settings.ncdets_max, size_t);
Expand All @@ -216,6 +216,7 @@ int main(int argc, char** argv) {
asci_E0 = input.getData<double>("ASCI.E0_WFN");
compute_asci_E0 = false;
}
OPT_KEYWORD("ASCI.PT2", pt2, bool);

bool mp2_guess = false;
OPT_KEYWORD("MCSCF.MP2_GUESS", mp2_guess, bool);
Expand Down Expand Up @@ -307,10 +308,8 @@ int main(int argc, char** argv) {
std::vector<double> active_ordm(n_active * n_active);
std::vector<double> active_trdm(active_ordm.size() * active_ordm.size());

bool pt2 = true;
double E0 = 0;
double EPT2 = 0;

// CI
if(job == Job::CI) {
using generator_t = macis::SortedDoubleLoopHamiltonianGenerator<wfn_type>;
Expand Down Expand Up @@ -401,7 +400,16 @@ int main(int argc, char** argv) {

if(asci_wfn_out_fname.size() and !world_rank) {
console->info("Writing ASCI Wavefunction to {}", asci_wfn_out_fname);
//if(reference_data_format == "TREXIO") {
// console->info(" * Format TREXIO");
// macis::TREXIOFile trexio_file(asci_wfn_out_fname, 'w', TREXIO_HDF5);
// trexio_file.write_mo_num(nwfn_bits/2); // Trick TREXIO
// trexio_file.write_determinant_list(dets.size(),
// reinterpret_cast<int64_t*>(dets.data()));
//} else {
console->info(" * Format TEXT");
macis::write_wavefunction(asci_wfn_out_fname, n_active, dets, C);
//}
}

// Dump Hamiltonian
Expand Down

0 comments on commit f249eee

Please sign in to comment.