Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/split index proposal #861

Closed
wants to merge 15 commits into from
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ if(CREATE_FORTRAN_BINDINGS)
enable_language(Fortran)
endif()

set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CUDA_STANDARD 14)

if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
Expand Down
4 changes: 2 additions & 2 deletions apps/tests/test_allgather.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ void test_allgather()
int N = 11;
std::vector<double> vec(N, 0.0);

sddk::splindex<sddk::splindex_t::block> spl(N, mpi::Communicator::world().size(), mpi::Communicator::world().rank());
sddk::splindex_block<> spl(N, n_blocks(mpi::Communicator::world().size()), block_id(mpi::Communicator::world().rank()));

for (int i = 0; i < spl.local_size(); i++) {
vec[spl[i]] = mpi::Communicator::world().rank() + 1.0;
vec[spl.global_index(i)] = mpi::Communicator::world().rank() + 1.0;
}

{
Expand Down
5 changes: 2 additions & 3 deletions apps/tests/test_wf_ortho.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,9 @@ void test_wf_ortho(BLACS_grid const& blacs_grid__, double cutoff__, int num_band
for (int igloc = 0; igloc < gvec->count(); igloc++) {
phi.pw_coeffs(igloc, s, wf::band_index(i)) = utils::random<std::complex<T>>();
}
for (int ialoc = 0; ialoc < phi.spl_num_atoms().local_size(); ialoc++) {
int ia = phi.spl_num_atoms()[ialoc];
for (auto [ia, lia] : phi.spl_num_atoms()) {
for (int xi = 0; xi < num_mt_coeffs[ia]; xi++) {
phi.mt_coeffs(xi, wf::atom_index(ialoc), s, wf::band_index(i)) = utils::random<std::complex<T>>();
phi.mt_coeffs(xi, lia, s, wf::band_index(i)) = utils::random<std::complex<T>>();
}
}
}
Expand Down
12 changes: 5 additions & 7 deletions apps/tests/test_wf_trans.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,9 @@ void test_wf_trans(la::BLACS_grid const& blacs_grid__, double cutoff__, int num_
for (int igloc = 0; igloc < gvec->count(); igloc++) {
phi.pw_coeffs(igloc, s, wf::band_index(i)) = utils::random<std::complex<T>>();
}
for (int ialoc = 0; ialoc < phi.spl_num_atoms().local_size(); ialoc++) {
int ia = phi.spl_num_atoms()[ialoc];
for (auto [ia, lia] : phi.spl_num_atoms()) {
for (int xi = 0; xi < num_mt_coeffs[ia]; xi++) {
phi.mt_coeffs(xi, wf::atom_index(ialoc), s, wf::band_index(i)) = utils::random<std::complex<T>>();
phi.mt_coeffs(xi, lia, s, wf::band_index(i)) = utils::random<std::complex<T>>();
}
}
}
Expand Down Expand Up @@ -86,12 +85,11 @@ void test_wf_trans(la::BLACS_grid const& blacs_grid__, double cutoff__, int num_
phi.pw_coeffs(igloc, s, wf::band_index(i)) -
psi.pw_coeffs(igloc, s, wf::band_index(num_bands__ - i - 1)));
}
for (int ialoc = 0; ialoc < phi.spl_num_atoms().local_size(); ialoc++) {
int ia = phi.spl_num_atoms()[ialoc];
for (auto [ia, li] : phi.spl_num_atoms()) {
for (int xi = 0; xi < num_mt_coeffs[ia]; xi++) {
diff += std::abs(
phi.mt_coeffs(xi, wf::atom_index(ialoc), s, wf::band_index(i)) -
psi.mt_coeffs(xi, wf::atom_index(ialoc), s, wf::band_index(num_bands__ - i - 1)));
phi.mt_coeffs(xi, li, s, wf::band_index(i)) -
psi.mt_coeffs(xi, li, s, wf::band_index(num_bands__ - i - 1)));
}
}
}
Expand Down
44 changes: 25 additions & 19 deletions apps/unit_tests/test_splindex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ int test1()
{
for (int num_ranks = 1; num_ranks < 20; num_ranks++) {
for (int N = 1; N < 1130; N++) {
splindex<splindex_t::block> spl(N, num_ranks, 0);
splindex_block<> spl(N, n_blocks(num_ranks), block_id(0));
int sz = 0;
for (int i = 0; i < num_ranks; i++) {
sz += (int)spl.local_size(i);
sz += spl.local_size(block_id(i));
}
if (sz != N) {
std::stringstream s;
Expand All @@ -20,21 +20,21 @@ int test1()
s << "computed global index size: " << sz << std::endl;
s << "number of ranks: " << num_ranks << std::endl;
for (int i = 0; i < num_ranks; i++) {
s << "i, local_size(i): " << i << ", " << spl.local_size(i) << std::endl;
s << "i, local_size(i): " << i << ", " << spl.local_size(block_id(i)) << std::endl;
}
throw std::runtime_error(s.str());
}
for (int i = 0; i < N; i++) {
int rank = spl.local_rank(i);
int offset = (int)spl.local_index(i);
if (i != (int)spl.global_index(offset, rank)) {
int rank = spl.location(block_id(i)).ib;
int offset = spl.location(block_id(i)).index_local;
if (i != (int)spl.global_index(offset, block_id(rank))) {
std::stringstream s;
s << "test1: wrong index." << std::endl;
s << "global index size: " << N << std::endl;
s << "number of ranks: " << num_ranks << std::endl;
s << "global index: " << i << std::endl;
s << "rank, offset: " << rank << ", " << offset << std::endl;
s << "computed global index: " << spl.global_index(offset, rank) << std::endl;
s << "computed global index: " << spl.global_index(offset, block_id(rank)) << std::endl;
throw std::runtime_error(s.str());
}
}
Expand All @@ -48,22 +48,28 @@ int test2()
for (int bs = 1; bs < 17; bs++) {
for (int num_ranks = 1; num_ranks < 13; num_ranks++) {
for (int N = 1; N < 1113; N++) {
splindex<splindex_t::block_cyclic> spl(N, num_ranks, 0, bs);
sddk::splindex_block_cyclic<> spl(N, n_blocks(num_ranks), block_id(0), bs);
int sz = 0;
for (int i = 0; i < num_ranks; i++) {
sz += (int)spl.local_size(i);
sz += (int)spl.local_size(block_id(i));
}
if (sz != N) {
std::stringstream s;

s << "test2: wrong sum of local sizes" << std::endl;
s << "test2: wrong sum of local sizes" << std::endl
<< "N : " << N << std::endl
<< "num_ranks :" << num_ranks << std::endl
<< "block size : " << bs << std::endl;
for (int i = 0; i < num_ranks; i++) {
s << "rank, local_size : " << i << ", " << spl.local_size(block_id(i)) << std::endl;
}
throw std::runtime_error(s.str());
}

for (int i = 0; i < N; i++) {
int rank = spl.local_rank(i);
int offset = (int)spl.local_index(i);
if (i != (int)spl.global_index(offset, rank)) {
int rank = spl.location(block_id(i)).ib;
int offset = spl.location(block_id(i)).index_local;
if (i != (int)spl.global_index(offset, block_id(rank))) {
std::stringstream s;
s << "test2: wrong index" << std::endl;
s << "bs = " << bs << std::endl
Expand All @@ -72,7 +78,7 @@ int test2()
<< "idx = " << i << std::endl
<< "rank = " << rank << std::endl
<< "offset = " << offset << std::endl
<< "computed index = " << spl.global_index(offset, rank) << std::endl;
<< "computed index = " << spl.global_index(offset, block_id(rank)) << std::endl;
throw std::runtime_error(s.str());
}
}
Expand All @@ -86,14 +92,14 @@ int test3()
{
for (int num_ranks = 1; num_ranks < 20; num_ranks++) {
for (int N = 1; N < 1130; N++) {
splindex<splindex_t::block> spl_tmp(N, num_ranks, 0);
splindex_block<> spl_tmp(N, n_blocks(num_ranks), block_id(0));

splindex<splindex_t::chunk> spl(N, num_ranks, 0, spl_tmp.counts());
splindex_chunk<> spl(N, n_blocks(num_ranks), block_id(0), spl_tmp.counts());

for (int i = 0; i < N; i++) {
int rank = spl.local_rank(i);
int offset = spl.local_index(i);
if (i != spl.global_index(offset, rank)) {
int rank = spl.location(block_id(i)).ib;
int offset = spl.location(block_id(i)).index_local;
if (i != spl.global_index(offset, block_id(rank))) {
std::stringstream s;
s << "test3: wrong index" << std::endl;
throw std::runtime_error(s.str());
Expand Down
2 changes: 1 addition & 1 deletion python_module/py_sirius.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ PYBIND11_MODULE(py_sirius, m)
[](K_point_set& ks, int i) -> K_point<double>& {
if (i >= ks.spl_num_kpoints().local_size())
throw pybind11::index_error("out of bounds");
return *ks.get<double>(ks.spl_num_kpoints(i));
return *ks.get<double>(ks.spl_num_kpoints().global_index(typename kp_index_t::local(i)));
},
py::return_value_policy::reference_internal)
.def("__len__", [](K_point_set const& ks) { return ks.spl_num_kpoints().local_size(); })
Expand Down
Loading
Loading