From 50f826938451449a0729512cfe6a243fb110208d Mon Sep 17 00:00:00 2001 From: Joseph Schuchart Date: Wed, 6 Apr 2022 11:29:01 -0400 Subject: [PATCH] Implement simple key ranges for broadcast Problem: broadcasting a range of keys currently requires the application to populate the full range into a container and pass it to ttg::broadcast. These keys then have to be processed and sent to peers, leading to large messages (in addition to the actual data transfer). Many (dense) applications send a range of keys that can be expressed as a linear function in an N-dimensional space. With key ranges, an application can express the start and end of a range as well as a step size for iterating through the range (a vector in the N-dimensional space). Some applications may require splitting the range, i.e., providing a list of ranges (e.g., the Floyd-Warshall example). These ranges can be serialized and sent to peers instead of the fully populated list of keys. Such a compact representation of the key range is cheaper to transfer and store. An example of using key ranges: ``` auto range1_low = ttg::make_keyrange(Key(I, 0, K), /* start */ Key(I, J, K), /* end, i.e., last + 1 */ Key(0, 1, 0)); /* step vector */ ``` It is possible to define a different type for the step vector by defining `Key::diff_type` or by specializing the `ttg::key_diff_type` trait. 
Signed-off-by: Joseph Schuchart --- examples/floyd-warshall/floyd_warshall_df.cc | 178 +++---- ttg/CMakeLists.txt | 1 + ttg/ttg.h | 1 + ttg/ttg/func.h | 8 +- ttg/ttg/keyrange.h | 209 ++++++++ ttg/ttg/parsec/ttg.h | 488 +++++++++++++++---- ttg/ttg/terminal.h | 85 +++- ttg/ttg/util/bits.h | 84 ++++ ttg/ttg/util/meta.h | 128 +++++ 9 files changed, 964 insertions(+), 218 deletions(-) create mode 100644 ttg/ttg/keyrange.h create mode 100644 ttg/ttg/util/bits.h diff --git a/examples/floyd-warshall/floyd_warshall_df.cc b/examples/floyd-warshall/floyd_warshall_df.cc index 1d523f4fcf..b42c42ee19 100644 --- a/examples/floyd-warshall/floyd_warshall_df.cc +++ b/examples/floyd-warshall/floyd_warshall_df.cc @@ -38,28 +38,37 @@ using namespace ttg; struct Key { // ((I, J), K) where (I, J) is the tile coordinate and K is the iteration number - std::pair, int> execution_info; + //std::pair, int> execution_info; + int I = 0, J = 0, K = 0; bool operator==(const Key& b) const { if (this == &b) return true; - return execution_info.first.first == b.execution_info.first.first && - execution_info.first.second == b.execution_info.first.second && - execution_info.second == b.execution_info.second; + return I == b.I && + J == b.J && + K == b.K; + } + + Key& operator+=(const Key& b) { + I += b.I; + J += b.J; + K += b.K; + rehash(); + return *this; } bool operator!=(const Key& b) const { return !((*this) == b); } madness::hashT hash_val; - Key() : execution_info(std::make_pair(std::make_pair(0, 0), 0)) { rehash(); } - Key(const std::pair, int>& e) : execution_info(e) { rehash(); } - Key(int e_f_f, int e_f_s, int e_s) : execution_info(std::make_pair(std::make_pair(e_f_f, e_f_s), e_s)) { rehash(); } + Key() { rehash(); } + //Key(const std::pair, int>& e) : execution_info(e) { rehash(); } + Key(int e_f_f, int e_f_s, int e_s) : I(e_f_f), J(e_f_s), K(e_s) { rehash(); } madness::hashT hash() const { return hash_val; } void rehash() { std::hash int_hasher; - hash_val = 
int_hasher(execution_info.first.first) * 2654435769 + int_hasher(execution_info.first.second) * 40503 + - int_hasher(execution_info.second); + hash_val = int_hasher(I) * 2654435769 + int_hasher(J) * 40503 + + int_hasher(K); } #ifdef TTG_SERIALIZATION_SUPPORTS_MADNESS @@ -72,14 +81,16 @@ struct Key { #ifdef TTG_SERIALIZATION_SUPPORTS_BOOST template void serialize(Archive& ar, const unsigned int) { - ar& execution_info; + ar& I; + ar& J; + ar& K; if constexpr (ttg::detail::is_boost_input_archive_v) rehash(); } #endif friend std::ostream& operator<<(std::ostream& out, Key const& k) { - out << "Key((" << k.execution_info.first.first << "," << k.execution_info.first.second << ")," - << k.execution_info.second << ")"; + out << "Key(" << k.I << "," << k.J << "," + << k.K << ")"; return out; } }; @@ -126,13 +137,13 @@ class Initiator : public TT(Key(std::make_pair(std::make_pair(i, j), 0)), kv.second, out); + ::send<0>(Key(i, j, 0), kv.second, out); } else if (i == 0) { // B function call - ::send<1>(Key(std::make_pair(std::make_pair(i, j), 0)), kv.second, out); + ::send<1>(Key(i, j, 0), kv.second, out); } else if (j == 0) { // C function call - ::send<2>(Key(std::make_pair(std::make_pair(i, j), 0)), kv.second, out); + ::send<2>(Key(i, j, 0), kv.second, out); } else { // D function call - ::send<3>(Key(std::make_pair(std::make_pair(i, j), 0)), kv.second, out); + ::send<3>(Key(i, j, 0), kv.second, out); } }); } @@ -177,8 +188,8 @@ class Finalizer : public TT, Finalizer, ttg::typelist &&in -- use direct arguments void op(const Key& key, const std::tuple>&& t, typename baseT::output_terminals_type& out) { - int I = key.execution_info.first.first; - int J = key.execution_info.first.second; + int I = key.I; + int J = key.J; int block_size = problem_size / blocking_factor; BlockMatrix bm = get<0>(t); @@ -253,9 +264,9 @@ class FuncA : public TT>& t, typename baseT::output_terminals_type& out) { - int I = key.execution_info.first.first; - int J = key.execution_info.first.second; - 
int K = key.execution_info.second; + int I = key.I; + int J = key.J; + int K = key.K; BlockMatrix m_ij; // Executing the update @@ -264,53 +275,35 @@ class FuncA : public TT(t)), (get<0>(t)), (get<0>(t))); // cout << "A[" << I << "," << J << "," << K << "]: " << m_ij << endl; } - // Making u_ready/v_ready for all the B/C function calls in the CURRENT iteration - std::tuple, std::vector, std::vector> bcast_keys; - for (int l = 0; l < blocking_factor; ++l) { - if (l != K) { - /*if (K == 0) { - // B calls - x_ready - ::send<1>(Key(std::make_pair(std::make_pair(I, l), K)), (*adjacency_matrix_ttg)(I,l), out); - // C calls - x_ready - ::send<2>(Key(std::make_pair(std::make_pair(l, J), K)), (*adjacency_matrix_ttg)(l,J), out); - }*/ - // B calls - // cout << "Send " << I << " " << l << " " << K << endl; - //::send<4>(Key(std::make_pair(std::make_pair(I, l), K)), m_ij, out); - std::get<1>(bcast_keys).emplace_back(I, l, K); - - // C calls - // cout << "Send " << l << " " << J << " " << K << endl; - //::send<5>(Key(std::make_pair(std::make_pair(l, J), K)), m_ij, out); - std::get<2>(bcast_keys).emplace_back(l, J, K); - } - } + /* create 2 ranges: 1) (I, l, K) with l = 0..K-1; 2) (I, l, K) with l = K+1..blocking_factor */ + auto range1_low = ttg::make_keyrange(Key(I, 0, K), Key(I, K, K), Key(0, 1, 0)); + auto range1_up = ttg::make_keyrange(Key(I, K+1, K), Key(I, blocking_factor, K), Key(0, 1, 0)); + + /* create 2 ranges: 1) (l, J, K) with l = 0..K-1; 2) (l, J, K) with l = K+1..blocking_factor */ + auto range2_low = ttg::make_keyrange(Key(0, J, K), Key(K, J, K), Key(1, 0, 0)); + auto range2_up = ttg::make_keyrange(Key(K+1, J, K), Key(blocking_factor, J, K), Key(1, 0, 0)); // making x_ready for the computation on the SAME block in the NEXT iteration if (K < (blocking_factor - 1)) { // if there is a NEXT iteration - std::get<0>(bcast_keys)[0] = {I, J, K+1}; + Key key0 = {I, J, K+1}; + auto bcast_keys = std::make_tuple(key0, + std::array{{range1_low, range1_up}}, + 
std::array{{range2_low, range2_up}}); if (I == K + 1 && J == K + 1) { // in the next iteration, we have A function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<0>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<0, 4, 5>(bcast_keys, std::move(m_ij), out); } else if (I == K + 1) { // in the next iteration, we have B function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<1>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<1, 4, 5>(bcast_keys, std::move(m_ij), out); } else if (J == K + 1) { // in the next iteration, we have C function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<2>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<2, 4, 5>(bcast_keys, std::move(m_ij), out); } else { // in the next iteration, we have D function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<3>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<3, 4, 5>(bcast_keys, std::move(m_ij), out); } } else { - std::get<0>(bcast_keys)[0] = {I, J, K}; - // cout << "A[" << I << "," << J << "," << K << "]: " << m_ij << endl; - //::send<6>(Key(std::make_pair(std::make_pair(I, J), K)), m_ij, out); + Key key0 = {I, J, K}; + auto bcast_keys = std::make_tuple(key0, + std::array{{range1_low, range1_up}}, + std::array{{range2_low, range2_up}}); ::broadcast<6, 4, 5>(bcast_keys, std::move(m_ij), out); } } @@ -353,9 +346,9 @@ class FuncB : public TT, const BlockMatrix>& t, typename baseT::output_terminals_type& out) { - int I = key.execution_info.first.first; - int J = key.execution_info.first.second; - int K = key.execution_info.second; + int I = key.I; + int J = key.J; + int K = key.K; BlockMatrix m_ij; // Executing the update @@ -366,20 +359,14 @@ class FuncB : public TT, std::vector> bcast_keys; - for (int i = 0; i < blocking_factor; ++i) { - if (i != I) { - // if (K == 0) - 
//::send<3>(Key(std::make_pair(std::make_pair(i, J), K)), (*adjacency_matrix_ttg)(i,J), out); - // cout << "Send " << i << " " << J << " " << K << endl; - //::send<4>(Key(std::make_pair(std::make_pair(i, J), K)), m_ij, out); - std::get<1>(bcast_keys).emplace_back(i, J, K); - } - } + /* create 2 ranges: 1) (I, l, K) with l = 0..K-1; 2) (I, l, K) with l = K+1..blocking_factor */ + auto range1_low = ttg::make_keyrange(Key(0, J, K), Key(I, J, K), Key(1, 0, 0)); + auto range1_up = ttg::make_keyrange(Key(I+1, J, K), Key(blocking_factor, J, K), Key(1, 0, 0)); // making x_ready for the computation on the SAME block in the NEXT iteration if (K < (blocking_factor - 1)) { // if there is a NEXT iteration - std::get<0>(bcast_keys)[0] = {I, J, K+1}; + Key key0 = {I, J, K+1}; + auto bcast_keys = std::make_tuple(key0, std::array{{range1_low, range1_up}}); if (I == K + 1 && J == K + 1) { // in the next iteration, we have A function call // cout << "Send " << I << " " << J << " " << K << endl; //::send<0>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); @@ -398,7 +385,8 @@ class FuncB : public TT(bcast_keys, std::move(m_ij), out); } } else { - std::get<0>(bcast_keys)[0] = {I, J, K}; + Key key0 = {I, J, K}; + auto bcast_keys = std::make_tuple(key0, std::array{{range1_low, range1_up}}); // cout << "B[" << I << "," << J << "," << K << "]: " << m_ij << endl; //::send<5>(Key(std::make_pair(std::make_pair(I, J), K)), m_ij, out); ::broadcast<5, 4>(bcast_keys, std::move(m_ij), out); @@ -443,9 +431,9 @@ class FuncC : public TT, BlockMatrix>& t, typename baseT::output_terminals_type& out) { - int I = key.execution_info.first.first; - int J = key.execution_info.first.second; - int K = key.execution_info.second; + int I = key.I; + int J = key.J; + int K = key.K; BlockMatrix m_ij; // Executing the update @@ -456,40 +444,26 @@ class FuncC : public TT, std::vector> bcast_keys; - for (int j = 0; j < blocking_factor; ++j) { - if (j != J) { - 
//::send<4>(Key(std::make_pair(std::make_pair(I, j), K)), (*adjacency_matrix_ttg)(I,j), out); - // cout << "Send " << I << " " << j << " " << K << endl; - //::send<4>(Key(std::make_pair(std::make_pair(I, j), K)), m_ij, out); - std::get<1>(bcast_keys).emplace_back(I, j, K); - } - } + auto range1_low = ttg::make_keyrange(Key(I, 0, K), Key(I, J, K), Key(0, 1, 0)); + auto range1_up = ttg::make_keyrange(Key(I, J+1, K), Key(I, blocking_factor, K), Key(0, 1, 0)); // making x_ready for the computation on the SAME block in the NEXT iteration if (K < (blocking_factor - 1)) { // if there is a NEXT iteration - std::get<0>(bcast_keys)[0] = {I, J, K+1}; + Key key0 = {I, J, K+1}; + auto bcast_keys = std::make_tuple(key0, std::array{{range1_low, range1_up}}); if (I == K + 1 && J == K + 1) { // in the next iteration, we have A function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<0>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<0, 4>(bcast_keys, std::move(m_ij), out); } else if (I == K + 1) { // in the next iteration, we have B function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<1>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<1, 4>(bcast_keys, std::move(m_ij), out); } else if (J == K + 1) { // in the next iteration, we have C function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<2>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<2, 4>(bcast_keys, std::move(m_ij), out); } else { // in the next iteration, we have D function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<3>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<3, 4>(bcast_keys, std::move(m_ij), out); } } else { // cout << "C[" << I << "," << J << "," << K << "]: " << m_ij << endl; - //::send<5>(Key(std::make_pair(std::make_pair(I, J), K)), m_ij, out); - std::get<0>(bcast_keys)[0] = {I, J, K}; + Key 
key0 = {I, J, K}; + auto bcast_keys = std::make_tuple(key0, std::array{{range1_low, range1_up}}); ::broadcast<5, 4>(bcast_keys, std::move(m_ij), out); } } @@ -532,9 +506,9 @@ class FuncD : public TT, const BlockMatrix, const BlockMatrix>& t, typename baseT::output_terminals_type& out) { - int I = key.execution_info.first.first; - int J = key.execution_info.first.second; - int K = key.execution_info.second; + int I = key.I; + int J = key.J; + int K = key.K; BlockMatrix m_ij; // Executing the update @@ -548,20 +522,20 @@ class FuncD : public TT(Key(std::make_pair(std::make_pair(I, J), K + 1)), std::move(m_ij), out); + ::send<0>(Key(I, J, K + 1), std::move(m_ij), out); } else if (I == K + 1) { // in the next iteration, we have B function call // cout << "Send " << I << " " << J << " " << K << endl; - ::send<1>(Key(std::make_pair(std::make_pair(I, J), K + 1)), std::move(m_ij), out); + ::send<1>(Key(I, J, K + 1), std::move(m_ij), out); } else if (J == K + 1) { // in the next iteration, we have C function call // cout << "Send " << I << " " << J << " " << K << endl; - ::send<2>(Key(std::make_pair(std::make_pair(I, J), K + 1)), std::move(m_ij), out); + ::send<2>(Key(I, J, K + 1), std::move(m_ij), out); } else { // in the next iteration, we have D function call // cout << "Send " << I << " " << J << " " << K << endl; - ::send<3>(Key(std::make_pair(std::make_pair(I, J), K + 1)), std::move(m_ij), out); + ::send<3>(Key(I, J, K + 1), std::move(m_ij), out); } } else { // cout << "D[" << I << "," << J << "," << K << "]: " << m_ij << endl; - ::send<4>(Key(std::make_pair(std::make_pair(I, J), K)), std::move(m_ij), out); + ::send<4>(Key(I, J, K), std::move(m_ij), out); } } }; @@ -724,8 +698,8 @@ int main(int argc, char** argv) { int P = std::sqrt(world.size()); int Q = world.size() / P; auto keymap = [=](const Key &key) { - int I = key.execution_info.first.first; - int J = key.execution_info.first.second; + int I = key.I; + int J = key.J; return ((I%P) + (J%Q)*P); }; diff --git 
a/ttg/CMakeLists.txt b/ttg/CMakeLists.txt index 49b4e07c63..9575a13a09 100644 --- a/ttg/CMakeLists.txt +++ b/ttg/CMakeLists.txt @@ -36,6 +36,7 @@ set(ttg-impl-headers ${CMAKE_CURRENT_SOURCE_DIR}/ttg/func.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/fwd.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/impl_selector.h + ${CMAKE_CURRENT_SOURCE_DIR}/ttg/keyrange.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/tt.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/reduce.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/run.h diff --git a/ttg/ttg.h b/ttg/ttg.h index e0fa9a7029..e840f2a475 100644 --- a/ttg/ttg.h +++ b/ttg/ttg.h @@ -17,6 +17,7 @@ #include "ttg/base/world.h" #include "ttg/broadcast.h" #include "ttg/func.h" +#include "ttg/keyrange.h" #include "ttg/reduce.h" #include "ttg/traverse.h" #include "ttg/tt.h" diff --git a/ttg/ttg/func.h b/ttg/ttg/func.h index 9849bfc528..4802ded503 100644 --- a/ttg/ttg/func.h +++ b/ttg/ttg/func.h @@ -238,7 +238,7 @@ namespace ttg { inline void broadcast(const std::tuple &keylists, valueT &&value, std::tuple...> &t) { if constexpr (ttg::meta::is_iterable_v>>) { - if (std::distance(std::begin(std::get(keylists)), std::end(std::get(keylists))) > 0) { + if (std::begin(std::get(keylists)) != std::end(std::get(keylists))) { std::get(t).broadcast(std::get(keylists), value); } } else { @@ -252,7 +252,7 @@ namespace ttg { template inline void broadcast(const std::tuple &keylists, valueT &&value) { if constexpr (ttg::meta::is_iterable_v>>) { - if (std::distance(std::begin(std::get(keylists)), std::end(std::get(keylists))) > 0) { + if (std::begin(std::get(keylists)) != std::end(std::get(keylists))) { using key_t = decltype(*std::begin(std::get(keylists))); auto *terminal_ptr = detail::get_out_terminal(i, "ttg::broadcast(keylists, value)"); terminal_ptr->broadcast(std::get(keylists), value); @@ -270,7 +270,7 @@ namespace ttg { template inline void broadcast(const std::tuple &keylists, std::tuple...> &t) { if constexpr (ttg::meta::is_iterable_v>>) { - if (std::distance(std::begin(std::get(keylists)), 
std::end(std::get(keylists))) > 0) { + if (std::begin(std::get(keylists)) != std::end(std::get(keylists))) { std::get(t).broadcast(std::get(keylists)); } } else { @@ -284,7 +284,7 @@ namespace ttg { template inline void broadcast(const std::tuple &keylists) { if constexpr (ttg::meta::is_iterable_v>>) { - if (std::distance(std::begin(std::get(keylists)), std::end(std::get(keylists))) > 0) { + if (std::begin(std::get(keylists)) != std::end(std::get(keylists))) { using key_t = decltype(*std::begin(std::get(keylists))); auto *terminal_ptr = detail::get_out_terminal(i, "ttg::broadcast(keylists)"); terminal_ptr->broadcast(std::get(keylists)); diff --git a/ttg/ttg/keyrange.h b/ttg/ttg/keyrange.h new file mode 100644 index 0000000000..602d4b1dbf --- /dev/null +++ b/ttg/ttg/keyrange.h @@ -0,0 +1,209 @@ +#ifndef TTG_KEYTERATOR_H +#define TTG_KEYTERATOR_H + +#include +#include +#include + +#include "ttg/util/meta.h" +#include "ttg/serialization.h" + +namespace ttg { + + /* Trait providing the diff type of a given key + * Defaults to the key itself. + * May be provided as \c diff_type member of the key or by + * overloading this trait. + */ + template + struct key_diff_type { + using type = Key; + }; + + /* Overload for Key::diff_type */ + template + struct key_diff_type>{ + using type = typename Key::diff_type; + }; + + /* Convenience type */ + template + using key_diff_type_t = typename key_diff_type::type; + + namespace detail { + /** + * Trait checking whether a key is compatible with the LinearKeyRange. + * Keys must at least support comparison as well as addition or compound addition. 
+ */ + template + struct is_range_compatible { + using key_type = std::decay_t; + using difference_type = key_diff_type_t; + constexpr static bool value = ttg::meta::is_comparable_v + && (ttg::meta::has_addition_v + || ttg::meta::has_compound_addition_v) + && (std::is_trivially_copyable_v || is_user_buffer_serializable_v); + }; + + template + constexpr bool is_range_compatible_v = is_range_compatible::value; + + /** + * Represents a range of keys that can be represented as a linear iteration + * space, i.e., using a start and end (one past the last element) as well as + * a step increment. An iterator is provided to iterate over the range of keys. + * + * The step increment is of type \sa ttg::key_diff_type, which defaults + * to the key type but can be overridden by either defining a \c diff_type + * member type or by specializing ttg::key_diff_type. + */ + template + struct LinearKeyRange { + using key_type = std::decay_t; + using diff_type = key_diff_type_t; + + /* Forward Iterator for the linear key range */ + struct iterator + { + /* types for std::iterator_trait */ + using value_type = key_type; + using difference_type = diff_type; + using pointer = const value_type*; + using reference = const value_type&; + using iterator_category = std::forward_iterator_tag; + + iterator(const key_type& pos, const diff_type& inc) + : m_pos(pos), m_inc(inc) + { } + + iterator& operator++() { + if constexpr (meta::has_compound_addition_v) { + m_pos += m_inc; + } else if constexpr (meta::has_addition_v) { + m_pos = m_pos + m_inc; + } else { + throw std::logic_error("Key type does not support addition its with difference type"); + } + return *this; + } + + iterator operator++(int) { + iterator retval = *this; + ++(*this); + return retval; + } + + bool operator==(const iterator& other) const { + if constexpr(meta::is_comparable_v) { + return m_pos == other.m_pos; + } + return false; + } + + bool operator!=(const iterator& other) const { + return !(*this == other); + } + const 
key_type& operator*() const { + return m_pos; + } + const key_type* operator->() const { + return &m_pos; + } + + private: + key_type m_pos; + const diff_type& m_inc; + }; + + LinearKeyRange() + { } + + LinearKeyRange(const Key& begin, const Key& end, const diff_type& inc) + : m_begin(begin) + , m_end(end) + , m_inc(inc) + { } + + iterator begin() const { + return iterator(m_begin, m_inc); + } + + iterator end() const { + return iterator(m_end, m_inc); + } + +#ifdef TTG_SERIALIZATION_SUPPORTS_MADNESS + /* Make the LinearKeyRange madness serializable */ + template + && is_madness_buffer_serializable_v>> + void serialize(Archive& ar) { + ar & m_begin; + ar & m_end; + ar & m_inc; + } +#endif + +#ifdef TTG_SERIALIZATION_SUPPORTS_BOOST + /* Make the LinearKeyRange boost serializable */ + template + && is_boost_buffer_serializable_v>> + void serialize(Archive& ar, const unsigned int version) { + ar & m_begin; + ar & m_end; + ar & m_inc; + } +#endif + + friend std::ostream& operator<<(std::ostream& out, LinearKeyRange const& k) { + out << "LinearKeyRange[" << k.m_begin << "," << k.m_end << "):"< + inline auto make_keyrange(const Key& begin, + const Key& end, + const key_diff_type_t& inc) { + static_assert(detail::is_range_compatible_v, + "Key type does not support all required operations: operator==, " + "operator+ or operator+=, and serialization (trivially_copyable, madness, or boost)"); + return detail::LinearKeyRange(begin, end, inc); + } + + /** + * Create a key range [begin, end) with unit stride. + * + * Requires \c operator++ and \c operator- to be defined on Key. + * If a difference type (\c Key::diff_type) is defined, \c operator- should return + * the difference type. + * + * \return A representation of the range that can be passed to send/broadcast. 
+ */ + template + inline auto make_keyrange(const Key& begin, const Key& end) { + static_assert(detail::is_range_compatible_v, + "Key type does not support all required operations: operator==, " + "operator+ or operator+=, and serialization (trivially_copyable, madness, or boost)"); + static_assert(meta::has_increment_v && meta::has_difference_v, + "Unit stride key range requires operator++ and operator- on Key"); + if constexpr (meta::has_pre_increment_v) { + return detail::LinearKeyRange(begin, end, ++begin - begin); + } else { + return detail::LinearKeyRange(begin, end, begin++ - begin); + } + } + +} // namespace ttg +#endif // TTG_KEYTERATOR_H diff --git a/ttg/ttg/parsec/ttg.h b/ttg/ttg/parsec/ttg.h index 5ff550a64c..0c70648f82 100644 --- a/ttg/ttg/parsec/ttg.h +++ b/ttg/ttg/parsec/ttg.h @@ -20,6 +20,7 @@ #include "ttg/runtimes.h" #include "ttg/terminal.h" #include "ttg/tt.h" +#include "ttg/util/bits.h" #include "ttg/util/env.h" #include "ttg/util/hash.h" #include "ttg/util/meta.h" @@ -79,8 +80,8 @@ namespace ttg_parsec { struct msg_header_t { typedef enum { MSG_SET_ARG = 0, MSG_SET_ARGSTREAM_SIZE = 1, MSG_FINALIZE_ARGSTREAM_SIZE = 2 } fn_id_t; + uint64_t tt_id; uint32_t taskpool_id; - uint64_t op_id; fn_id_t fn_id; int32_t param_id; int num_keys; @@ -93,12 +94,12 @@ namespace ttg_parsec { static_set_arg_fct_type static_set_arg_fct; parsec_taskpool_t *tp = NULL; msg_header_t *msg = static_cast(data); - uint64_t op_id = msg->op_id; + uint64_t tt_id = msg->tt_id; tp = parsec_taskpool_lookup(msg->taskpool_id); assert(NULL != tp); static_map_mutex.lock(); try { - auto op_pair = static_id_to_op_map.at(op_id); + auto op_pair = static_id_to_op_map.at(tt_id); static_map_mutex.unlock(); tp->tdm.module->incoming_message_start(tp, src_rank, NULL, NULL, 0, NULL); static_set_arg_fct = op_pair.first; @@ -110,8 +111,8 @@ namespace ttg_parsec { assert(data_cpy != 0); memcpy(data_cpy, data, size); ttg::trace("ttg_parsec(", ttg_default_execution_context().rank(), ") Delaying 
delivery of message (", src_rank, - ", ", op_id, ", ", data_cpy, ", ", size, ")"); - delayed_unpack_actions.insert(std::make_pair(op_id, std::make_tuple(src_rank, data_cpy, size))); + ", ", tt_id, ", ", data_cpy, ", ", size, ")"); + delayed_unpack_actions.insert(std::make_pair(tt_id, std::make_tuple(src_rank, data_cpy, size))); static_map_mutex.unlock(); return 1; } @@ -500,14 +501,15 @@ namespace ttg_parsec { .key_hash = parsec_hash_table_generic_64bits_key_hash}; template - class rma_delayed_activate { + class rma_delayed_activate_keylist { std::vector _keylist; std::atomic _outstanding_transfers; ActivationCallbackT _cb; ttg_data_copy_t *_copy; public: - rma_delayed_activate(std::vector &&key, ttg_data_copy_t *copy, int num_transfers, ActivationCallbackT cb) + rma_delayed_activate_keylist(std::vector &&key, ttg_data_copy_t *copy, + int num_transfers, ActivationCallbackT cb) : _keylist(std::move(key)), _outstanding_transfers(num_transfers), _cb(cb), _copy(copy) {} bool complete_transfer(void) { @@ -520,6 +522,51 @@ namespace ttg_parsec { } }; + template + class rma_delayed_activate_keyranges { + std::vector> _keyranges; + std::atomic _outstanding_transfers; + ActivationCallbackT _cb; + ttg_data_copy_t *_copy; + + public: + rma_delayed_activate_keyranges(std::vector> &&ranges, ttg_data_copy_t *copy, + int num_transfers, ActivationCallbackT cb) + : _keyranges(std::move(ranges)), _outstanding_transfers(num_transfers), _cb(cb), _copy(copy) {} + + bool complete_transfer(void) { + int left = --_outstanding_transfers; + if (0 == left) { + _cb(_keyranges, _copy); + return true; + } + return false; + } + }; + + template + class rma_delayed_activate_keyrange { + ttg::detail::LinearKeyRange _keyrange; + std::atomic _outstanding_transfers; + ActivationCallbackT _cb; + ttg_data_copy_t *_copy; + + public: + rma_delayed_activate_keyrange(ttg::detail::LinearKeyRange &range, ttg_data_copy_t *copy, + int num_transfers, ActivationCallbackT cb) + : _keyrange(range), 
_outstanding_transfers(num_transfers), _cb(cb), _copy(copy) {} + + bool complete_transfer(void) { + int left = --_outstanding_transfers; + if (0 == left) { + _cb(_keyrange, _copy); + return true; + } + return false; + } + }; + + template static int get_complete_cb(parsec_comm_engine_t *comm_engine, parsec_ce_mem_reg_handle_t lreg, ptrdiff_t ldispl, parsec_ce_mem_reg_handle_t rreg, ptrdiff_t rdispl, size_t size, int remote, @@ -732,7 +779,7 @@ namespace ttg_parsec { msg_t() = default; msg_t(uint64_t tt_id, uint32_t taskpool_id, msg_header_t::fn_id_t fn_id, int32_t param_id, int num_keys = 1) - : tt_id{taskpool_id, tt_id, fn_id, param_id, num_keys} {} + : tt_id{tt_id, taskpool_id, fn_id, param_id, num_keys} {} }; } // namespace detail @@ -997,8 +1044,8 @@ namespace ttg_parsec { return &mempools.thread_mempools[index]; } - template - void set_arg_from_msg_keylist(ttg::span &&keylist, ttg_data_copy_t *copy) { + template + void set_arg_from_msg_keylist(Iterator&& begin, Iterator&& end, ttg_data_copy_t *copy) { /* create a dummy task that holds the copy, which can be reused by others */ task_t *dummy; parsec_execution_stream_s *es = world.impl().execution_stream(); @@ -1015,8 +1062,11 @@ namespace ttg_parsec { /* iterate over the keys and have them use the copy we made */ parsec_task_t *task_ring = nullptr; - for (auto &&key : keylist) { - set_arg_local_impl(key, *reinterpret_cast(copy->device_private), copy, &task_ring); + for (auto it = begin; it != end; ++it) { + if constexpr (HasNonLocalKeys) { + if (keymap(*it) != world.rank()) continue; + } + set_arg_local_impl(*it, *reinterpret_cast(copy->device_private), copy, &task_ring); } if (nullptr != task_ring) { @@ -1032,6 +1082,37 @@ namespace ttg_parsec { parsec_thread_mempool_free(mempool, &dummy->parsec_task); } + inline auto extract_keylist_from_msg(detail::msg_t *msg, uint64_t& pos) { + std::vector keylist; + int num_keys = msg->tt_id.num_keys; + keylist.reserve(num_keys); + auto rank = world.rank(); + for (int k = 
0; k < num_keys; ++k) { + keyT key; + pos = unpack(key, msg->bytes, pos); + assert(keymap(key) == rank); + keylist.push_back(std::move(key)); + } + return keylist; + } + + inline auto extract_keyrange_from_msg(detail::msg_t *msg, uint64_t& pos) { + ttg::detail::LinearKeyRange range; + pos = unpack(range, msg->bytes, pos); + return range; + } + + inline auto extract_keyranges_from_msg(detail::msg_t *msg, uint64_t& pos) { + int num_ranges = -msg->tt_id.num_keys; + std::vector> ranges(num_ranges); + for (int i = 0; i < num_ranges; ++i) { + ttg::detail::LinearKeyRange range; + pos = unpack(range, msg->bytes, pos); + ranges.push_back(std::move(range)); + } + return ranges; + } + // there are 6 types of set_arg: // - case 1: nonvoid Key, complete Value type // - case 2: nonvoid Key, void Value, mixed (data+control) inputs @@ -1049,15 +1130,17 @@ namespace ttg_parsec { if constexpr (!ttg::meta::is_void_v) { /* unpack the keys */ uint64_t pos = 0; + /* we can have either a list of keys, a single range, or a list of ranges */ std::vector keylist; + std::vector> keyranges; + ttg::detail::LinearKeyRange keyrange; int num_keys = msg->tt_id.num_keys; - keylist.reserve(num_keys); - auto rank = world.rank(); - for (int k = 0; k < num_keys; ++k) { - keyT key; - pos = unpack(key, msg->bytes, pos); - assert(keymap(key) == rank); - keylist.push_back(std::move(key)); + if (num_keys >= 0) { + keylist = extract_keylist_from_msg(msg, pos); + } else if (num_keys == -1) { + keyrange = extract_keyrange_from_msg(msg, pos); + } else { + keyranges = extract_keyranges_from_msg(msg, pos); } // case 1 if constexpr (!ttg::meta::is_void_v) { @@ -1065,9 +1148,17 @@ namespace ttg_parsec { if constexpr (!ttg::has_split_metadata::value) { ttg_data_copy_t *copy = detail::create_new_datacopy(decvalueT{}); unpack(*static_cast(copy->device_private), msg->bytes, pos); - - set_arg_from_msg_keylist(ttg::span(&keylist[0], num_keys), copy); + if (num_keys >= 0) { + set_arg_from_msg_keylist(keylist.begin(), 
keylist.end(), copy); + } else if (num_keys == -1) { + set_arg_from_msg_keylist(keyrange.begin(), keyrange.end(), copy); + } else { + for (const auto& range : keyranges) { + set_arg_from_msg_keylist(range.begin(), range.end(), copy); + } + } } else { + /* unpack the header and start the RMA transfers */ ttg::SplitMetadataDescriptor descr; using metadata_t = decltype(descr.get_metadata(std::declval())); @@ -1093,23 +1184,61 @@ namespace ttg_parsec { ttg_data_copy_t *copy = detail::create_new_datacopy(descr.create_from_metadata(metadata)); /* nothing else to do if the object is empty */ if (0 == num_iovecs) { - set_arg_from_msg_keylist(keylist, copy); + if (num_keys >= 0) { + set_arg_from_msg_keylist(keylist.begin(), keylist.end(), copy); + } else if (num_keys == -1) { + set_arg_from_msg_keylist(keyrange.begin(), keyrange.end(), copy); + } else { + for (const auto& range : keyranges) { + set_arg_from_msg_keylist(range.begin(), range.end(), copy); + } + } } else { /* extract the callback tag */ parsec_ce_tag_t cbtag; std::memcpy(&cbtag, msg->bytes + pos, sizeof(cbtag)); pos += sizeof(cbtag); + /* pick the right activation callback */ + void* activation_ptr; + parsec_ce_onesided_callback_t cb; + if (num_keys >= 0) { + auto activation = new detail::rma_delayed_activate_keylist( + std::move(keylist), copy, num_iovecs, + [this](std::vector &&keylist, ttg_data_copy_t *copy) { + set_arg_from_msg_keylist(keylist.begin(), keylist.end(), copy); + this->world.impl().decrement_inflight_msg(); + }); + activation_ptr = activation; + using ActivationT = std::decay_t; + cb = &detail::get_complete_cb; + } else if (num_keys == -1) { + auto activation = new detail::rma_delayed_activate_keyrange( + keyrange, copy, num_iovecs, + [this](ttg::detail::LinearKeyRange &keyrange, ttg_data_copy_t *copy) { + set_arg_from_msg_keylist(keyrange.begin(), keyrange.end(), copy); + this->world.impl().decrement_inflight_msg(); + }); + activation_ptr = activation; + using ActivationT = std::decay_t; + cb 
= &detail::get_complete_cb; + } else { + auto activation = new detail::rma_delayed_activate_keyranges( + std::move(keyranges), copy, num_iovecs, + [this](std::vector> &keyranges, ttg_data_copy_t *copy) { + for (const auto& range : keyranges) { + set_arg_from_msg_keylist(range.begin(), range.end(), copy); + } + this->world.impl().decrement_inflight_msg(); + }); + activation_ptr = activation; + using ActivationT = std::decay_t; + cb = &detail::get_complete_cb; + } + /* create the value from the metadata */ - auto activation = new detail::rma_delayed_activate( - std::move(keylist), copy, num_iovecs, [this](std::vector &&keylist, ttg_data_copy_t *copy) { - set_arg_from_msg_keylist(keylist, copy); - this->world.impl().decrement_inflight_msg(); - }); auto &val = *static_cast(copy->device_private); - using ActivationT = std::decay_t; - int nv = 0; /* process payload iovecs */ auto iovecs = descr.get_data(val); @@ -1136,7 +1265,7 @@ namespace ttg_parsec { world.impl().increment_inflight_msg(); /* TODO: PaRSEC should treat the remote callback as a tag, not a function pointer! 
*/ parsec_ce.get(&parsec_ce, lreg, 0, rreg, 0, iov.num_bytes, remote, - &detail::get_complete_cb, activation, + cb, activation_ptr, /*world.impl().parsec_ttg_rma_tag()*/ cbtag, &fn_ptr, sizeof(std::intptr_t)); } @@ -1147,8 +1276,20 @@ namespace ttg_parsec { } // case 2 and 3 } else if constexpr (!ttg::meta::is_void_v && std::is_void_v) { - for (auto &&key : keylist) { - set_arg(key, ttg::Void{}); + if (num_keys >= 0) { + for (auto &&key : keylist) { + set_arg(key, ttg::Void{}); + } + } else if (num_keys == -1) { + for (const auto &key : keyrange) { + set_arg(key, ttg::Void{}); + } + } else { + for (const auto& range : keyranges) { + for (const auto &key : range) { + set_arg(key, ttg::Void{}); + } + } } } // case 4 @@ -1571,12 +1712,11 @@ namespace ttg_parsec { parsec_taskpool_t *tp = world_impl.taskpool(); tp->tdm.module->outgoing_message_start(tp, owner, NULL); tp->tdm.module->outgoing_message_pack(tp, owner, NULL, NULL, 0); - // std::cout << "Sending AM with " << msg->op_id.num_keys << " keys " << std::endl; parsec_ce.send_am(&parsec_ce, world_impl.parsec_ttg_tag(), owner, static_cast(msg.get()), sizeof(msg_header_t) + pos); } - template + template void broadcast_arg_local(Iterator &&begin, Iterator &&end, const Value &value) { parsec_task_t *task_ring = nullptr; ttg_data_copy_t *copy = nullptr; @@ -1585,6 +1725,9 @@ namespace ttg_parsec { } for (auto it = begin; it != end; ++it) { + if constexpr (HasNonLocalKeys) { + if (keymap(*it) != world.rank()) continue; + } set_arg_local_impl(*it, value, copy, &task_ring); } /* submit all ready tasks at once */ @@ -1657,10 +1800,133 @@ namespace ttg_parsec { sizeof(msg_header_t) + pos); } /* handle local keys */ - broadcast_arg_local(local_begin, local_end, value); + broadcast_arg_local(local_begin, local_end, value); } else { /* only local keys */ - broadcast_arg_local(keylist.begin(), keylist.end(), value); + broadcast_arg_local(keylist.begin(), keylist.end(), value); + } + } + + + template + std::enable_if_t && 
!std::is_void_v> && + !ttg::has_split_metadata>::value, + void> + rangecast_arg(const ttg::span> &ranges, const Value &value) { + auto world = ttg_default_execution_context(); + int rank = world.rank(); + + auto rankset = ttg::bitset(world.size()); + /* flag every rank for which we have a key */ + for (const auto& range : ranges) { + for (const auto& key : range) { + rankset.set(keymap(key)); + } + } + bool have_remote = rankset.popcnt() > 1 || (rankset.popcnt() == 1 && !rankset.get(rank)); + + if (have_remote) { + + size_t pos = 0; + using msg_t = detail::msg_t; + auto &world_impl = world.impl(); + std::unique_ptr msg = std::make_unique(get_instance_id(), world_impl.taskpool()->taskpool_id, + msg_header_t::MSG_SET_ARG, i); + + parsec_taskpool_t *tp = world_impl.taskpool(); + + for (int owner = 0; owner < world.size(); ++owner) { + if (rank == owner || !rankset[owner]) continue; + + /* pack the key range */ + int num_ranges = 0; + for (const auto& range : ranges) { + pos = pack(range, msg->bytes, pos); + ++num_ranges; + } + msg->tt_id.num_keys = -num_ranges; // mark as ranges + + /* TODO: use RMA to transfer large values */ + pos = pack(value, msg->bytes, pos); + + /* Send the message */ + tp->tdm.module->outgoing_message_start(tp, owner, NULL); + tp->tdm.module->outgoing_message_pack(tp, owner, NULL, NULL, 0); + parsec_ce.send_am(&parsec_ce, world_impl.parsec_ttg_tag(), owner, static_cast(msg.get()), + sizeof(msg_header_t) + pos); + } + /* handle local keys */ + if (rankset[rank]) { + for (const auto& range : ranges) { + broadcast_arg_local(range.begin(), range.end(), value); + } + } + } else { + /* only local keys */ + for (const auto& range : ranges) { + broadcast_arg_local(range.begin(), range.end(), value); + } + } + } + + + template + inline + auto splitmd_get_iovs(const Value &value, + ttg::SplitMetadataDescriptor &descr) { + auto iovs = descr.get_data(*const_cast(&value)); + int32_t num_iovs = std::distance(std::begin(iovs), std::end(iovs)); + std::vector>> 
memregs; + memregs.reserve(num_iovs); + + /* register all iovs so the registration can be reused */ + for (auto &&iov : iovs) { + parsec_ce_mem_reg_handle_t lreg; + size_t lreg_size; + parsec_ce.mem_register(iov.data, PARSEC_MEM_TYPE_NONCONTIGUOUS, iov.num_bytes, parsec_datatype_int8_t, + iov.num_bytes, &lreg, &lreg_size); + /* TODO: use a static function for deregistration here? */ + memregs.push_back(std::make_pair(static_cast(lreg_size), + /* TODO: this assumes that parsec_ce_mem_reg_handle_t is void* */ + std::shared_ptr{lreg, [](void *ptr) { + parsec_ce_mem_reg_handle_t memreg = + (parsec_ce_mem_reg_handle_t)ptr; + parsec_ce.mem_unregister(&memreg); + }})); + } + return memregs; + } + + /** + * pack the registration handles + * memory layout: [, ...] + */ + template + inline + void splitmd_pack_reghandles(std::vector>> memregs, + detail::msg_t* msg, size_t& pos) { + + for (auto &&memreg : memregs) { + int32_t lreg_size; + std::shared_ptr lreg_ptr; + std::tie(lreg_size, lreg_ptr) = memreg; + std::memcpy(msg->bytes + pos, &lreg_size, sizeof(lreg_size)); + pos += sizeof(lreg_size); + std::memcpy(msg->bytes + pos, lreg_ptr.get(), lreg_size); + pos += lreg_size; + /* create a function that will be invoked upon RMA completion at the target */ + std::shared_ptr lreg_ptr_v = lreg_ptr; + /* mark another reader on the copy */ + ttg_data_copy_t *copy = detail::register_data_copy(copy, nullptr, true); + std::function *fn = new std::function([=]() mutable { + /* shared_ptr of value and registration captured by value so resetting + * them here will eventually release the memory/registration */ + detail::release_data_copy(copy); + lreg_ptr_v.reset(); + }); + std::intptr_t fn_ptr{reinterpret_cast(fn)}; + std::memcpy(msg->bytes + pos, &fn_ptr, sizeof(fn_ptr)); + pos += sizeof(fn_ptr); } } @@ -1691,26 +1957,8 @@ namespace ttg_parsec { auto local_end = keylist_sorted.end(); ttg::SplitMetadataDescriptor descr; - auto iovs = descr.get_data(*const_cast(&value)); - int32_t num_iovs 
= std::distance(std::begin(iovs), std::end(iovs)); - std::vector>> memregs; - memregs.reserve(num_iovs); - - /* register all iovs so the registration can be reused */ - for (auto &&iov : iovs) { - parsec_ce_mem_reg_handle_t lreg; - size_t lreg_size; - parsec_ce.mem_register(iov.data, PARSEC_MEM_TYPE_NONCONTIGUOUS, iov.num_bytes, parsec_datatype_int8_t, - iov.num_bytes, &lreg, &lreg_size); - /* TODO: use a static function for deregistration here? */ - memregs.push_back(std::make_pair(static_cast(lreg_size), - /* TODO: this assumes that parsec_ce_mem_reg_handle_t is void* */ - std::shared_ptr{lreg, [](void *ptr) { - parsec_ce_mem_reg_handle_t memreg = - (parsec_ce_mem_reg_handle_t)ptr; - parsec_ce.mem_unregister(&memreg); - }})); - } + std::vector>> memregs = splitmd_get_iovs(value, descr); + size_t num_iovs = memregs.size(); using msg_t = detail::msg_t; auto &world_impl = world.impl(); @@ -1763,47 +2011,103 @@ namespace ttg_parsec { parsec_ce_tag_t cbtag = reinterpret_cast(&detail::get_remote_complete_cb); std::memcpy(msg->bytes + pos, &cbtag, sizeof(cbtag)); pos += sizeof(cbtag); + splitmd_pack_reghandles(memregs, msg.get(), pos); + tp->tdm.module->outgoing_message_start(tp, owner, NULL); + tp->tdm.module->outgoing_message_pack(tp, owner, NULL, NULL, 0); + parsec_ce.send_am(&parsec_ce, world_impl.parsec_ttg_tag(), owner, static_cast(msg.get()), + sizeof(msg_header_t) + pos); + } + /* handle local keys */ + broadcast_arg_local(local_begin, local_end, value); + } else { + /* handle local keys */ + broadcast_arg_local(keylist.begin(), keylist.end(), value); + } + } - /** - * pack the registration handles - * memory layout: [, ...] 
- */ - int idx = 0; - for (auto &&iov : iovs) { - // auto [lreg_size, lreg_ptr] = memregs[idx]; - int32_t lreg_size; - std::shared_ptr lreg_ptr; - std::tie(lreg_size, lreg_ptr) = memregs[idx]; - std::memcpy(msg->bytes + pos, &lreg_size, sizeof(lreg_size)); - pos += sizeof(lreg_size); - std::memcpy(msg->bytes + pos, lreg_ptr.get(), lreg_size); - pos += lreg_size; - /* create a function that will be invoked upon RMA completion at the target */ - std::shared_ptr lreg_ptr_v = lreg_ptr; - /* mark another reader on the copy */ - copy = detail::register_data_copy(copy, nullptr, true); - std::function *fn = new std::function([=]() mutable { - /* shared_ptr of value and registration captured by value so resetting - * them here will eventually release the memory/registration */ - detail::release_data_copy(copy); - lreg_ptr_v.reset(); - }); - std::intptr_t fn_ptr{reinterpret_cast(fn)}; - std::memcpy(msg->bytes + pos, &fn_ptr, sizeof(fn_ptr)); - pos += sizeof(fn_ptr); - ++idx; + template + std::enable_if_t && !std::is_void_v> && + ttg::has_split_metadata>::value, + void> + splitmd_rangecast_arg(const ttg::span> &ranges, const Value &value) { + using valueT = std::tuple_element_t; + auto world = ttg_default_execution_context(); + int rank = world.rank(); + auto rankset = ttg::bitset(world.size()); + /* flag every rank for which we have a key */ + for (const auto& range : ranges) { + for (const auto& key : range) { + rankset.set(keymap(key)); + } + } + bool have_remote = rankset.popcnt() > 1 || (rankset.popcnt() == 1 && !rankset[rank]); + + if (have_remote) { + using decvalueT = std::decay_t; + + ttg::SplitMetadataDescriptor descr; + std::vector>> memregs = splitmd_get_iovs(value, descr); + size_t num_iovs = memregs.size(); + + using msg_t = detail::msg_t; + auto &world_impl = world.impl(); + std::unique_ptr msg = std::make_unique(get_instance_id(), world_impl.taskpool()->taskpool_id, + msg_header_t::MSG_SET_ARG, i); + auto metadata = descr.get_metadata(value); + size_t 
metadata_size = sizeof(metadata); + + ttg_data_copy_t *copy; + copy = detail::find_copy_in_task(parsec_ttg_caller, &value); + assert(nullptr != copy); + + parsec_taskpool_t *tp = world_impl.taskpool(); + size_t pos = 0; + for (int owner = 0; owner < world.size(); ++owner) { + if (rank == owner || !rankset[owner]) continue; + + int num_ranges = 0; + for (const auto& range : ranges) { + /* pack the key range objects */ + pos = pack(range, msg->bytes, pos); + ++num_ranges; } + msg->tt_id.num_keys = -num_ranges; // mark as keyrange + + /* pack the metadata */ + std::memcpy(msg->bytes + pos, &metadata, metadata_size); + pos += metadata_size; + /* pack the local rank */ + int rank = world.rank(); + std::memcpy(msg->bytes + pos, &rank, sizeof(rank)); + pos += sizeof(rank); + /* pack the number of iovecs */ + std::memcpy(msg->bytes + pos, &num_iovs, sizeof(num_iovs)); + pos += sizeof(num_iovs); + + /* TODO: at the moment, the tag argument to parsec_ce.get() is treated as a + * raw function pointer instead of a preregistered AM tag, so play that game. + * Once this is fixed in PaRSEC we need to use parsec_ttg_rma_tag instead! 
*/ + parsec_ce_tag_t cbtag = reinterpret_cast(&detail::get_remote_complete_cb); + std::memcpy(msg->bytes + pos, &cbtag, sizeof(cbtag)); + pos += sizeof(cbtag); + splitmd_pack_reghandles(memregs, msg.get(), pos); tp->tdm.module->outgoing_message_start(tp, owner, NULL); tp->tdm.module->outgoing_message_pack(tp, owner, NULL, NULL, 0); parsec_ce.send_am(&parsec_ce, world_impl.parsec_ttg_tag(), owner, static_cast(msg.get()), sizeof(msg_header_t) + pos); } /* handle local keys */ - broadcast_arg_local(local_begin, local_end, value); + if (rankset[world.rank()]) { + for (const auto& range : ranges) { + broadcast_arg_local(range.begin(), range.end(), value); + } + } } else { - /* handle local keys */ - broadcast_arg_local(keylist.begin(), keylist.end(), value); + for (const auto& range : ranges) { + broadcast_arg_local(range.begin(), range.end(), value); + } } + } // Used by invoke to set all arguments associated with a task @@ -2091,9 +2395,17 @@ namespace ttg_parsec { broadcast_arg(keylist, value); } }; + auto rangecast_callback = [this](const ttg::span> &keylist, + const valueT &value) { + if constexpr (ttg::has_split_metadata>::value) { + splitmd_rangecast_arg(keylist, value); + } else { + rangecast_arg(keylist, value); + } + }; auto setsize_callback = [this](const keyT &key, std::size_t size) { set_argstream_size(key, size); }; auto finalize_callback = [this](const keyT &key) { finalize_argstream(key); }; - input.set_callback(send_callback, move_callback, broadcast_callback, setsize_callback, finalize_callback); + input.set_callback(send_callback, move_callback, broadcast_callback, setsize_callback, finalize_callback, rangecast_callback); } ////////////////////////////////////////////////////////////////// // case 2: nonvoid key, void value, mixed inputs diff --git a/ttg/ttg/terminal.h b/ttg/ttg/terminal.h index 7cea6003e6..a4329b7f35 100644 --- a/ttg/ttg/terminal.h +++ b/ttg/ttg/terminal.h @@ -7,6 +7,7 @@ #include "ttg/base/terminal.h" #include "ttg/fwd.h" +#include 
"ttg/keyrange.h" #include "ttg/util/demangle.h" #include "ttg/util/meta.h" #include "ttg/util/trace.h" @@ -85,6 +86,7 @@ namespace ttg { using send_callback_type = meta::detail::send_callback_t>; using move_callback_type = meta::detail::move_callback_t>; using broadcast_callback_type = meta::detail::broadcast_callback_t>; + using rangecast_callback_type = meta::detail::rangecast_callback_t>; using setsize_callback_type = typename base_type::setsize_callback_type; using finalize_callback_type = typename base_type::finalize_callback_type; static constexpr bool is_an_input_terminal = true; @@ -93,6 +95,7 @@ namespace ttg { send_callback_type send_callback; move_callback_type move_callback; broadcast_callback_type broadcast_callback; + rangecast_callback_type rangecast_callback; // No moving, copying, assigning permitted In(In &&other) = delete; @@ -110,10 +113,12 @@ namespace ttg { void set_callback(const send_callback_type &send_callback, const move_callback_type &move_callback, const broadcast_callback_type &bcast_callback = broadcast_callback_type{}, const setsize_callback_type &setsize_callback = setsize_callback_type{}, - const finalize_callback_type &finalize_callback = finalize_callback_type{}) { + const finalize_callback_type &finalize_callback = finalize_callback_type{}, + const rangecast_callback_type &rangecast_callback = rangecast_callback_type{}) { this->send_callback = send_callback; this->move_callback = move_callback; this->broadcast_callback = bcast_callback; + this->rangecast_callback = rangecast_callback; base_type::set_callback(setsize_callback, finalize_callback); } @@ -157,7 +162,9 @@ namespace ttg { // An optimized implementation will need a separate callback for broadcast // with a specific value for rangeT template - std::enable_if_t, void> broadcast(const rangeT &keylist, const Value &value) { + std::enable_if_t && + !meta::is_iterable_of_v>, void> + broadcast(const rangeT &keylist, const Value &value) { if (broadcast_callback) { if constexpr 
(ttg::meta::is_iterable_v) { broadcast_callback(ttg::span(&(*std::begin(keylist)), std::distance(std::begin(keylist), std::end(keylist))), @@ -176,29 +183,10 @@ namespace ttg { } } - template - std::enable_if_t, void> broadcast(const rangeT &keylist, Value &&value) { - const Value &v = value; - if (broadcast_callback) { - if constexpr (ttg::meta::is_iterable_v) { - broadcast_callback( - ttg::span(&(*std::begin(keylist)), std::distance(std::begin(keylist), std::end(keylist))), v); - } else { - /* got something we cannot iterate over (single element?) so put one element in the span */ - broadcast_callback(ttg::span(&keylist, 1), v); - } - } else { - if constexpr (ttg::meta::is_iterable_v) { - for (auto &&key : keylist) send(key, v); - } else { - /* got something we cannot iterate over (single element?) so put one element in the span */ - broadcast_callback(ttg::span(&keylist, 1), v); - } - } - } - template - std::enable_if_t, void> broadcast(const rangeT &keylist) { + std::enable_if_t && + !meta::is_iterable_of_v>, void> + broadcast(const rangeT &keylist) { if (broadcast_callback) { if constexpr (ttg::meta::is_iterable_v) { broadcast_callback( @@ -216,6 +204,55 @@ namespace ttg { } } } + + /** + * Overload for key ranges + */ + template + std::enable_if_t, void> broadcast(const ttg::detail::LinearKeyRange &range, const Value &value) { + if (rangecast_callback) { + rangecast_callback(ttg::span(range, 1), value); + } else { + for (const auto &key : range) send(key, value); + } + } + + template + std::enable_if_t, void> broadcast(const ttg::detail::LinearKeyRange &range) { + if (rangecast_callback) { + rangecast_callback(ttg::span(range, 1)); + } else { + for (const auto &key : range) sendk(key); + } + } + + template + std::enable_if_t && + meta::is_iterable_of_v>, void> + broadcast(const rangeT &rangelist, const Value &value) { + if (rangecast_callback) { + rangecast_callback(ttg::span(&(*std::begin(rangelist)), std::distance(std::begin(rangelist), 
std::end(rangelist))), + value); + } else { + for (const auto& range : rangelist) + for (const auto &key : range) + send(key, value); + } + } + + template + std::enable_if_t && + meta::is_iterable_of_v>, void> + broadcast(const rangeT &rangelist) { + if (rangecast_callback) { + rangecast_callback(ttg::span(&(*std::begin(rangelist)), std::distance(std::begin(rangelist), std::end(rangelist)))); + } else { + for (const auto& range : rangelist) + for (const auto &key : range) + sendk(key); + } + } + }; template diff --git a/ttg/ttg/util/bits.h b/ttg/ttg/util/bits.h new file mode 100644 index 0000000000..dee2c2950f --- /dev/null +++ b/ttg/ttg/util/bits.h @@ -0,0 +1,84 @@ +#ifndef TTG_BITS_H +#define TTG_BITS_H + +#ifdef __cpp_lib_bitops +#include +#endif // __cpp_lib_bitops + +#include + +/** + * Implement some functions of the header introduced in C++20. + * + * Also provides a dynamically allocated bitset. + */ + +namespace ttg { + +#ifdef __cpp_lib_bitops + template< class T > + constexpr int popcount( T x ) noexcept { + return std::popcount(x); + } +#else + template< class T > + constexpr int popcount( T x ) noexcept { + int res = 0; + for (int i = 0; i < sizeof(T)*8; ++i) { + res += !!(x & (1< m_storage; + mutable ssize_t m_popcnt = -1; + constexpr static size_t storage_size = sizeof(storage_type); + + public: + bitset(size_t size) : m_storage((size+storage_size-1)/storage_size) + { } + + void set(size_t i) noexcept { + m_storage[i/storage_size] |= 1<<(i%storage_size); + m_popcnt = -1; + } + + bool get(size_t i) const noexcept { + return !!(m_storage[i/storage_size] & (1<<(i%storage_size))); + } + + bool operator[](size_t i) const noexcept { + return get(i); + } + + size_t size() const noexcept { + return m_storage.size()*storage_size; + } + + size_t popcnt() const noexcept { + if (m_popcnt == -1) { + m_popcnt = 0; + for (const auto& v : m_storage) { + m_popcnt += ttg::popcount(v); + } + } + return m_popcnt; + } + + void clear() noexcept { + size_t size = 
m_storage.size(); + m_storage.clear(); + m_storage.resize(size); + } + }; + + +} // namespace ttg + +#endif // TTG_BITS_H diff --git a/ttg/ttg/util/meta.h b/ttg/ttg/util/meta.h index 179bf2dbc5..79c2cfcca9 100644 --- a/ttg/ttg/util/meta.h +++ b/ttg/ttg/util/meta.h @@ -11,6 +11,11 @@ namespace ttg { class Void; + namespace detail { + template + struct LinearKeyRange; + } // namespace detail + namespace meta { #if __cplusplus >= 201703L @@ -650,6 +655,31 @@ namespace ttg { template using broadcast_callback_t = typename broadcast_callback::type; + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // rangecast_callback_t = std::function, protected against void key or value + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + struct rangecast_callback; + template + struct rangecast_callback && !is_void_v>> { + using type = std::function> &, const Value &)>; + }; + template + struct rangecast_callback && is_void_v>> { + using type = std::function> &)>; + }; + template + struct rangecast_callback && !is_void_v>> { + using type = std::function; + }; + template + struct rangecast_callback && is_void_v>> { + using type = std::function; + }; + template + using rangecast_callback_t = typename rangecast_callback::type; + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // setsize_callback_t = std::function protected against void key //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -742,6 +772,21 @@ namespace ttg { template constexpr bool is_iterable_v = is_iterable::value; + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // check whether a type is iterable yielding a specified value type + 
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + struct is_iterable_of : std::false_type {}; + + // this gets used only when we can call std::begin() and std::end() on that type + template + struct is_iterable_of()))>> + && std::is_same_v()))>>>> + : std::true_type {}; + + template + constexpr bool is_iterable_of_v = is_iterable_of::value; + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // check whether a Callable is invocable with the arguments given as a typelist //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -755,6 +800,89 @@ namespace ttg { constexpr bool is_invocable_typelist_r_v> = std::is_invocable_r_v; + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // check whether a type has operator++ defined + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + struct has_pre_increment : std::false_type + { }; + + template + struct has_pre_increment()))>> + : std::true_type + { }; + + template + constexpr bool has_pre_increment_v = has_pre_increment::value; + + template + struct has_post_increment : std::false_type + { }; + + template + struct has_post_increment())++)>> + : std::true_type + { }; + + template + constexpr bool has_post_increment_v = has_post_increment::value; + + template + constexpr bool has_increment_v = has_pre_increment::value || has_post_increment::value; + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // check whether a type has operator- and operator+ defined + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + struct 
has_difference : std::false_type + { }; + + template + struct has_difference()) - (std::declval()))>> + : std::true_type + { }; + + template + constexpr bool has_difference_v = has_difference::value; + + template + struct has_addition : std::false_type + { }; + + template + struct has_addition()) + (std::declval()))>> + : std::true_type + { }; + + template + constexpr bool has_addition_v = has_addition::value; + + template + struct has_compound_addition : std::false_type + { }; + + template + struct has_compound_addition()) += (std::declval()))>> + : std::true_type + { }; + + template + constexpr bool has_compound_addition_v = has_compound_addition::value; + + /* Check if a type is comparable with itself. + * TODO: use of C++20 is_comparable once it's available */ + template + struct is_comparable : std::false_type + { }; + + template + struct is_comparable()) == (std::declval()))>> + : std::true_type + { }; + + template + constexpr bool is_comparable_v = is_comparable::value; + } // namespace meta } // namespace ttg