Skip to content

Commit

Permalink
[feature] new split index implementation (#862)
Browse files (browse the repository at this point in the history)
new split index implementation
  • Loading branch information
toxa81 authored Jul 13, 2023
1 parent 52bf38b commit 687ccaa
Show file tree
Hide file tree
Showing 59 changed files with 898 additions and 961 deletions.
4 changes: 2 additions & 2 deletions apps/tests/test_allgather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ void test_allgather()
int N = 11;
std::vector<double> vec(N, 0.0);

sddk::splindex<sddk::splindex_t::block> spl(N, mpi::Communicator::world().size(), mpi::Communicator::world().rank());
sddk::splindex_block<> spl(N, n_blocks(mpi::Communicator::world().size()), block_id(mpi::Communicator::world().rank()));

for (int i = 0; i < spl.local_size(); i++) {
vec[spl[i]] = mpi::Communicator::world().rank() + 1.0;
vec[spl.global_index(i)] = mpi::Communicator::world().rank() + 1.0;
}

{
Expand Down
7 changes: 3 additions & 4 deletions apps/tests/test_wf_ortho.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,9 @@ void test_wf_ortho(BLACS_grid const& blacs_grid__, double cutoff__, int num_band
for (int igloc = 0; igloc < gvec->count(); igloc++) {
phi.pw_coeffs(igloc, s, wf::band_index(i)) = utils::random<std::complex<T>>();
}
for (int ialoc = 0; ialoc < phi.spl_num_atoms().local_size(); ialoc++) {
int ia = phi.spl_num_atoms()[ialoc];
for (int xi = 0; xi < num_mt_coeffs[ia]; xi++) {
phi.mt_coeffs(xi, wf::atom_index(ialoc), s, wf::band_index(i)) = utils::random<std::complex<T>>();
for (auto it : phi.spl_num_atoms()) {
for (int xi = 0; xi < num_mt_coeffs[it.i]; xi++) {
phi.mt_coeffs(xi, it.li, s, wf::band_index(i)) = utils::random<std::complex<T>>();
}
}
}
Expand Down
16 changes: 7 additions & 9 deletions apps/tests/test_wf_trans.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ void test_wf_trans(la::BLACS_grid const& blacs_grid__, double cutoff__, int num_
for (int igloc = 0; igloc < gvec->count(); igloc++) {
phi.pw_coeffs(igloc, s, wf::band_index(i)) = utils::random<std::complex<T>>();
}
for (int ialoc = 0; ialoc < phi.spl_num_atoms().local_size(); ialoc++) {
int ia = phi.spl_num_atoms()[ialoc];
for (int xi = 0; xi < num_mt_coeffs[ia]; xi++) {
phi.mt_coeffs(xi, wf::atom_index(ialoc), s, wf::band_index(i)) = utils::random<std::complex<T>>();
for (auto it : phi.spl_num_atoms()) {
for (int xi = 0; xi < num_mt_coeffs[it.i]; xi++) {
phi.mt_coeffs(xi, it.li, s, wf::band_index(i)) = utils::random<std::complex<T>>();
}
}
}
Expand Down Expand Up @@ -86,12 +85,11 @@ void test_wf_trans(la::BLACS_grid const& blacs_grid__, double cutoff__, int num_
phi.pw_coeffs(igloc, s, wf::band_index(i)) -
psi.pw_coeffs(igloc, s, wf::band_index(num_bands__ - i - 1)));
}
for (int ialoc = 0; ialoc < phi.spl_num_atoms().local_size(); ialoc++) {
int ia = phi.spl_num_atoms()[ialoc];
for (int xi = 0; xi < num_mt_coeffs[ia]; xi++) {
for (auto it : phi.spl_num_atoms()) {
for (int xi = 0; xi < num_mt_coeffs[it.i]; xi++) {
diff += std::abs(
phi.mt_coeffs(xi, wf::atom_index(ialoc), s, wf::band_index(i)) -
psi.mt_coeffs(xi, wf::atom_index(ialoc), s, wf::band_index(num_bands__ - i - 1)));
phi.mt_coeffs(xi, it.li, s, wf::band_index(i)) -
psi.mt_coeffs(xi, it.li, s, wf::band_index(num_bands__ - i - 1)));
}
}
}
Expand Down
44 changes: 25 additions & 19 deletions apps/unit_tests/test_splindex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ int test1()
{
for (int num_ranks = 1; num_ranks < 20; num_ranks++) {
for (int N = 1; N < 1130; N++) {
splindex<splindex_t::block> spl(N, num_ranks, 0);
splindex_block<> spl(N, n_blocks(num_ranks), block_id(0));
int sz = 0;
for (int i = 0; i < num_ranks; i++) {
sz += (int)spl.local_size(i);
sz += spl.local_size(block_id(i));
}
if (sz != N) {
std::stringstream s;
Expand All @@ -20,21 +20,21 @@ int test1()
s << "computed global index size: " << sz << std::endl;
s << "number of ranks: " << num_ranks << std::endl;
for (int i = 0; i < num_ranks; i++) {
s << "i, local_size(i): " << i << ", " << spl.local_size(i) << std::endl;
s << "i, local_size(i): " << i << ", " << spl.local_size(block_id(i)) << std::endl;
}
throw std::runtime_error(s.str());
}
for (int i = 0; i < N; i++) {
int rank = spl.local_rank(i);
int offset = (int)spl.local_index(i);
if (i != (int)spl.global_index(offset, rank)) {
int rank = spl.location(block_id(i)).ib;
int offset = spl.location(block_id(i)).index_local;
if (i != (int)spl.global_index(offset, block_id(rank))) {
std::stringstream s;
s << "test1: wrong index." << std::endl;
s << "global index size: " << N << std::endl;
s << "number of ranks: " << num_ranks << std::endl;
s << "global index: " << i << std::endl;
s << "rank, offset: " << rank << ", " << offset << std::endl;
s << "computed global index: " << spl.global_index(offset, rank) << std::endl;
s << "computed global index: " << spl.global_index(offset, block_id(rank)) << std::endl;
throw std::runtime_error(s.str());
}
}
Expand All @@ -48,22 +48,28 @@ int test2()
for (int bs = 1; bs < 17; bs++) {
for (int num_ranks = 1; num_ranks < 13; num_ranks++) {
for (int N = 1; N < 1113; N++) {
splindex<splindex_t::block_cyclic> spl(N, num_ranks, 0, bs);
sddk::splindex_block_cyclic<> spl(N, n_blocks(num_ranks), block_id(0), bs);
int sz = 0;
for (int i = 0; i < num_ranks; i++) {
sz += (int)spl.local_size(i);
sz += (int)spl.local_size(block_id(i));
}
if (sz != N) {
std::stringstream s;

s << "test2: wrong sum of local sizes" << std::endl;
s << "test2: wrong sum of local sizes" << std::endl
<< "N : " << N << std::endl
<< "num_ranks :" << num_ranks << std::endl
<< "block size : " << bs << std::endl;
for (int i = 0; i < num_ranks; i++) {
s << "rank, local_size : " << i << ", " << spl.local_size(block_id(i)) << std::endl;
}
throw std::runtime_error(s.str());
}

for (int i = 0; i < N; i++) {
int rank = spl.local_rank(i);
int offset = (int)spl.local_index(i);
if (i != (int)spl.global_index(offset, rank)) {
int rank = spl.location(block_id(i)).ib;
int offset = spl.location(block_id(i)).index_local;
if (i != (int)spl.global_index(offset, block_id(rank))) {
std::stringstream s;
s << "test2: wrong index" << std::endl;
s << "bs = " << bs << std::endl
Expand All @@ -72,7 +78,7 @@ int test2()
<< "idx = " << i << std::endl
<< "rank = " << rank << std::endl
<< "offset = " << offset << std::endl
<< "computed index = " << spl.global_index(offset, rank) << std::endl;
<< "computed index = " << spl.global_index(offset, block_id(rank)) << std::endl;
throw std::runtime_error(s.str());
}
}
Expand All @@ -86,14 +92,14 @@ int test3()
{
for (int num_ranks = 1; num_ranks < 20; num_ranks++) {
for (int N = 1; N < 1130; N++) {
splindex<splindex_t::block> spl_tmp(N, num_ranks, 0);
splindex_block<> spl_tmp(N, n_blocks(num_ranks), block_id(0));

splindex<splindex_t::chunk> spl(N, num_ranks, 0, spl_tmp.counts());
splindex_chunk<> spl(N, n_blocks(num_ranks), block_id(0), spl_tmp.counts());

for (int i = 0; i < N; i++) {
int rank = spl.local_rank(i);
int offset = spl.local_index(i);
if (i != spl.global_index(offset, rank)) {
int rank = spl.location(block_id(i)).ib;
int offset = spl.location(block_id(i)).index_local;
if (i != spl.global_index(offset, block_id(rank))) {
std::stringstream s;
s << "test3: wrong index" << std::endl;
throw std::runtime_error(s.str());
Expand Down
2 changes: 1 addition & 1 deletion python_module/py_sirius.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ PYBIND11_MODULE(py_sirius, m)
[](K_point_set& ks, int i) -> K_point<double>& {
if (i >= ks.spl_num_kpoints().local_size())
throw pybind11::index_error("out of bounds");
return *ks.get<double>(ks.spl_num_kpoints(i));
return *ks.get<double>(ks.spl_num_kpoints().global_index(typename kp_index_t::local(i)));
},
py::return_value_policy::reference_internal)
.def("__len__", [](K_point_set const& ks) { return ks.spl_num_kpoints().local_size(); })
Expand Down
Loading

0 comments on commit 687ccaa

Please sign in to comment.