Skip to content

Commit

Permalink
Refactor dist_constraint_gen into separate generation + distribution …
Browse files Browse the repository at this point in the history
…functions, add MPI-RMA dynamic load balancer based on the NWChem-NXTVAL trick (HT @jeffhammond)
  • Loading branch information
David Williams-Young committed Oct 30, 2023
1 parent 43488f2 commit 3d21260
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 99 deletions.
135 changes: 135 additions & 0 deletions include/macis/asci/mask_constraints.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ auto make_triplet(unsigned i, unsigned j, unsigned k) {
}

#ifdef MACIS_ENABLE_MPI
#if 0
template <typename WfnType, typename ContainerType>
auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
size_t nd_othr,
Expand Down Expand Up @@ -625,6 +626,140 @@ auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,

return constraints;
}
#else
template <typename WfnType, typename ContainerType>
auto gen_constraints_general(size_t nlevels, size_t norb, size_t ns_othr,
size_t nd_othr,
const ContainerType& unique_alpha,
int world_size) {

using wfn_traits = wavefunction_traits<WfnType>;
using constraint_type = alpha_constraint<wfn_traits>;
using string_type = typename constraint_type::constraint_type;

constexpr bool flat_container = std::is_same_v<
std::decay_t<WfnType>,
std::decay_t<typename ContainerType::value_type>
>;

// Generate triplets + heuristic
std::vector<std::pair<constraint_type, size_t>> constraint_sizes;
constraint_sizes.reserve(norb * norb * norb);
size_t total_work = 0;
for(int t_i = 0; t_i < norb; ++t_i)
for(int t_j = 0; t_j < t_i; ++t_j)
for(int t_k = 0; t_k < t_j; ++t_k) {
auto constraint = constraint_type::make_triplet(t_i, t_j, t_k);

size_t nw = 0;
for(const auto& alpha : unique_alpha) {
if constexpr (flat_container)
nw += constraint_histogram(wfn_traits::alpha_string(alpha), ns_othr,
nd_othr, constraint);
else
nw += alpha.second *
constraint_histogram(alpha.first, ns_othr, nd_othr, constraint);
}
if(nw) constraint_sizes.emplace_back(constraint, nw);
total_work += nw;
}

size_t local_average = (0.8 * total_work) / world_size;

for(size_t ilevel = 0; ilevel < nlevels; ++ilevel) {
// Select constraints larger than average to be broken apart
std::vector<std::pair<constraint_type, size_t>> tps_to_next;
{
auto it = std::partition(
constraint_sizes.begin(), constraint_sizes.end(),
[=](const auto& a) { return a.second <= local_average; });

// Remove constraints from full list
tps_to_next = decltype(tps_to_next)(it, constraint_sizes.end());
constraint_sizes.erase(it, constraint_sizes.end());
for(auto [t, s] : tps_to_next) total_work -= s;
}

if(!tps_to_next.size()) break;

// Break apart constraints
for(auto [c, nw_trip] : tps_to_next) {
const auto C_min = c.C_min();

// Loop over possible constraints with one more element
for(auto q_l = 0; q_l < C_min; ++q_l) {
// Generate masks / counts
string_type cn_C = c.C();
cn_C.flip(q_l);
string_type cn_B = c.B() >> (C_min - q_l);
constraint_type c_next(cn_C, cn_B, q_l);

size_t nw = 0;

for(const auto& alpha : unique_alpha) {
if constexpr (flat_container)
nw += constraint_histogram(wfn_traits::alpha_string(alpha), ns_othr,
nd_othr, c_next);
else
nw += alpha.second *
constraint_histogram(alpha.first, ns_othr, nd_othr, c_next);
}
if(nw) constraint_sizes.emplace_back(c_next, nw);
total_work += nw;
}
}
} // Recurse into constraints

// Sort to get optimal bucket partitioning
std::sort(constraint_sizes.begin(), constraint_sizes.end(),
[](const auto& a, const auto& b) { return a.second > b.second; });

return constraint_sizes;
}

/**
 *  @brief Generate constraints and return the ones assigned to the calling
 *         rank of `comm`.
 *
 *  Constraints and their work estimates are produced by
 *  `gen_constraints_general` and then handed out one at a time, in the order
 *  returned, to the rank whose accumulated work estimate is currently the
 *  smallest (ties go to the lowest rank, since `std::min_element` returns
 *  the first minimum). Every rank replays the complete assignment locally
 *  and keeps only its own share; the loop itself exchanges no data, so all
 *  ranks are expected to call this with identical inputs.
 *
 *  @tparam WfnType        Wavefunction type forwarded to
 *                         `gen_constraints_general`.
 *  @tparam ContainerType  Container type of `unique_alpha`.
 *
 *  @param[in] nlevels      Forwarded to `gen_constraints_general`.
 *  @param[in] norb         Forwarded to `gen_constraints_general`.
 *  @param[in] ns_othr      Forwarded to `gen_constraints_general`.
 *  @param[in] nd_othr      Forwarded to `gen_constraints_general`.
 *  @param[in] unique_alpha Forwarded to `gen_constraints_general`.
 *  @param[in] comm         Communicator queried for rank and size.
 *
 *  @returns `std::vector` of the constraints assigned to the calling rank.
 */
template <typename WfnType, typename ContainerType>
auto dist_constraint_general(size_t nlevels, size_t norb, size_t ns_othr,
                             size_t nd_othr, const ContainerType& unique_alpha,
                             MPI_Comm comm) {
  using wfn_traits = wavefunction_traits<WfnType>;
  using constraint_type = alpha_constraint<wfn_traits>;

  const auto world_rank = comm_rank(comm);
  const auto world_size = comm_size(comm);

  // Generate constraints subject to expected workload
  const auto constraint_sizes = gen_constraints_general<WfnType>(
      nlevels, norb, ns_othr, nd_othr, unique_alpha, world_size);

  // Accumulated work estimate of every rank
  std::vector<size_t> workloads(world_size, 0);

  // Constraints kept by this rank
  std::vector<constraint_type> constraints;
  constraints.reserve(constraint_sizes.size() / world_size);

  for(const auto& [c, nw] : constraint_sizes) {
    // Get rank with least amount of work
    const auto min_rank_it =
        std::min_element(workloads.begin(), workloads.end());
    const int min_rank =
        static_cast<int>(std::distance(workloads.begin(), min_rank_it));

    // Assign constraint
    *min_rank_it += nw;
    if(world_rank == min_rank) constraints.emplace_back(c);
  }

  return constraints;
}
#endif
#endif

#if 0
Expand Down
137 changes: 38 additions & 99 deletions include/macis/asci/pt2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,15 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,

auto gen_c_st = clock_type::now();
//auto constraints = dist_constraint_general<wfn_t<N>>(
// 5, norb, n_sing_beta, n_doub_beta, uniq_alpha_wfn, comm);
auto constraints = dist_constraint_general<wfn_t<N>>(
5, norb, n_sing_beta, n_doub_beta, uniq_alpha, comm);
// 5, norb, n_sing_beta, n_doub_beta, uniq_alpha, comm);
auto constraints = gen_constraints_general<wfn_t<N>>(
5, norb, n_sing_beta, n_doub_beta, uniq_alpha, world_size);
auto gen_c_en = clock_type::now();
duration_type gen_c_dur = gen_c_en - gen_c_st;
logger->info(" * GEN_DUR = {:.2e} ms", gen_c_dur.count());

size_t max_size = std::min(100000000ul,
ncdets * (n_sing_alpha + n_sing_beta + // AA + BB
n_doub_alpha + n_doub_beta + // AAAA + BBBB
n_sing_alpha * n_sing_beta // AABB
));
size_t max_size = 100000000ul;

double EPT2 = 0.0;
size_t NPT2 = 0;
auto pt2_st = clock_type::now();
Expand All @@ -135,102 +132,41 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
}
//std::mutex print_barrier;

const size_t ncon_total = constraints.size();
duration_type lock_wait_dur(0.0);
MPI_Win window;
//MPI_Win_create( &window_count, sizeof(size_t), sizeof(size_t), MPI_INFO_NULL, comm, &window );
size_t* window_buffer;
MPI_Win_allocate( sizeof(size_t), sizeof(size_t), MPI_INFO_NULL, comm, &window_buffer, &window);
if(window == MPI_WIN_NULL) throw std::runtime_error("Window failed");
MPI_Win_lock_all(MPI_MODE_NOCHECK, window);
// Process ASCI pair contributions for each constraint
#pragma omp parallel
//#pragma omp parallel reduction(+ : EPT2) reduction(+ : NPT2)
{
asci_contrib_container<wfn_t<N>> asci_pairs;
asci_pairs.reserve(max_size);
#pragma omp for reduction(+ : EPT2) reduction(+ : NPT2)
for(size_t ic = 0; ic < constraints.size(); ++ic) {
const auto& con = constraints[ic];
if(ic >= print_points.front()) {
//std::lock_guard<std::mutex> lock(print_barrier);
printf("[rank %d] %.1f done\n", world_rank, double(ic)/constraints.size()*100);
print_points.pop_front();
}
//#pragma omp for
//for(size_t ic = 0; ic < constraints.size(); ++ic)
size_t ic = 0;
while(ic < ncon_total)
{
size_t ntake = 10;
MPI_Fetch_and_op(&ntake, &ic, MPI_UINT64_T, 0, 0, MPI_SUM, window);
MPI_Win_flush(0, window);

// Loop over assigned tasks
const size_t c_end = std::min( ncon_total, ic + ntake);
for(; ic < c_end; ++ic ) {

const auto& con = constraints[ic].first;
//if(ic >= print_points.front()) {
// //std::lock_guard<std::mutex> lock(print_barrier);
// printf("[rank %d] %.1f done\n", world_rank, double(ic)/constraints.size()*100);
// print_points.pop_front();
//}
printf("[rank %d] %lu / %lu\n", world_rank, ic, ncon_total);
const double h_el_tol = 1e-16;

#if 0
// Loop over unique alpha strings
for(size_t i_alpha = 0; i_alpha < nuniq_alpha; ++i_alpha) {
const auto& det = uniq_alpha_wfn[i_alpha];
const auto occ_alpha = bits_to_indices(det);

// AA excitations
for(const auto& bcd : uad[i_alpha].bcd) {
const auto& beta = bcd.beta_string;
const auto& coeff = bcd.coeff;
const auto& h_diag = bcd.h_diag;
const auto& occ_beta = bcd.occ_beta;
const auto& orb_ens_alpha = bcd.orb_ens_alpha;
generate_constraint_singles_contributions_ss(
coeff, det | beta, con, occ_alpha, occ_beta, orb_ens_alpha.data(),
T_pq, norb, G_red, norb, V_red, norb, h_el_tol, h_diag, E_ASCI,
ham_gen, asci_pairs);
}

// AAAA excitations
for(const auto& bcd : uad[i_alpha].bcd) {
const auto& beta = bcd.beta_string;
const auto& coeff = bcd.coeff;
const auto& h_diag = bcd.h_diag;
const auto& occ_beta = bcd.occ_beta;
const auto& orb_ens_alpha = bcd.orb_ens_alpha;
generate_constraint_doubles_contributions_ss(
coeff, det | beta, con, occ_alpha, occ_beta, orb_ens_alpha.data(),
G_pqrs, norb, h_el_tol, h_diag, E_ASCI, ham_gen, asci_pairs);
}

// AABB excitations
for(const auto& bcd : uad[i_alpha].bcd) {
const auto& beta = bcd.beta_string;
const auto& coeff = bcd.coeff;
const auto& h_diag = bcd.h_diag;
const auto& occ_beta = bcd.occ_beta;
const auto& vir_beta = bcd.vir_beta;
const auto& orb_ens_alpha = bcd.orb_ens_alpha;
const auto& orb_ens_beta = bcd.orb_ens_beta;
generate_constraint_doubles_contributions_os(
coeff, det | beta, con, occ_alpha, occ_beta, vir_beta,
orb_ens_alpha.data(), orb_ens_beta.data(), V_pqrs, norb, h_el_tol,
h_diag, E_ASCI, ham_gen, asci_pairs);
}

// If the alpha determinant satisfies the constraint,
// append BB and BBBB excitations
if(satisfies_constraint(wfn_traits::alpha_string(det), con)) {
for(const auto& bcd : uad[i_alpha].bcd) {
const auto& beta = bcd.beta_string;
const auto& coeff = bcd.coeff;
const auto& h_diag = bcd.h_diag;
const auto& occ_beta = bcd.occ_beta;
const auto& vir_beta = bcd.vir_beta;
const auto& eps_beta = bcd.orb_ens_beta;

const auto state = det | beta;
const auto state_alpha = wfn_traits::alpha_string(state);
const auto state_beta = wfn_traits::beta_string(beta);
// BB Excitations
append_singles_asci_contributions<Spin::Beta>(
coeff, state, state_beta, occ_beta, vir_beta, occ_alpha,
eps_beta.data(), T_pq, norb, G_red, norb, V_red, norb, h_el_tol,
h_diag, E_ASCI, ham_gen, asci_pairs);

// BBBB Excitations
append_ss_doubles_asci_contributions<Spin::Beta>(
coeff, state, state_beta, state_alpha, occ_beta, vir_beta,
occ_alpha, eps_beta.data(), G_pqrs, norb, h_el_tol, h_diag,
E_ASCI, ham_gen, asci_pairs);

// No excition - to remove for PT2
asci_pairs.push_back(
{state, std::numeric_limits<double>::infinity(), 1.0});
} // Beta Loop
} // Triplet Check

} // Unique Alpha Loop
#else

for(size_t i_alpha = 0, iw = 0; i_alpha < nuniq_alpha; ++i_alpha) {
const auto& alpha_det = uniq_alpha[i_alpha].first;
const auto occ_alpha = bits_to_indices(alpha_det);
Expand Down Expand Up @@ -286,7 +222,6 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,

} // Unique Alpha Loop

#endif

double EPT2_local = 0.0;
// Local S&A for each quad + update EPT2
Expand All @@ -304,6 +239,7 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
}

EPT2 += EPT2_local;
} // Loc constraint loop
} // Constraint Loop
}
auto pt2_en = clock_type::now();
Expand All @@ -320,9 +256,12 @@ double asci_pt2_constraint(wavefunction_iterator_t<N> cdets_begin,
} else {
logger->info("* PT2_DUR = ${:.2e} ms", local_pt2_dur);
}
printf("[rank %d] WAIT_DUR = %.2e\n", world_rank, lock_wait_dur.count());

NPT2 = allreduce(NPT2, MPI_SUM, comm);
logger->info("* NPT2 = {}", NPT2);
MPI_Win_unlock_all(window);
MPI_Win_free(&window);

return EPT2;
}
Expand Down

0 comments on commit 3d21260

Please sign in to comment.