diff --git a/examples/floyd-warshall/floyd_warshall_df.cc b/examples/floyd-warshall/floyd_warshall_df.cc index 1d523f4fcf..b42c42ee19 100644 --- a/examples/floyd-warshall/floyd_warshall_df.cc +++ b/examples/floyd-warshall/floyd_warshall_df.cc @@ -38,28 +38,37 @@ using namespace ttg; struct Key { // ((I, J), K) where (I, J) is the tile coordinate and K is the iteration number - std::pair, int> execution_info; + //std::pair, int> execution_info; + int I = 0, J = 0, K = 0; bool operator==(const Key& b) const { if (this == &b) return true; - return execution_info.first.first == b.execution_info.first.first && - execution_info.first.second == b.execution_info.first.second && - execution_info.second == b.execution_info.second; + return I == b.I && + J == b.J && + K == b.K; + } + + Key& operator+=(const Key& b) { + I += b.I; + J += b.J; + K += b.K; + rehash(); + return *this; } bool operator!=(const Key& b) const { return !((*this) == b); } madness::hashT hash_val; - Key() : execution_info(std::make_pair(std::make_pair(0, 0), 0)) { rehash(); } - Key(const std::pair, int>& e) : execution_info(e) { rehash(); } - Key(int e_f_f, int e_f_s, int e_s) : execution_info(std::make_pair(std::make_pair(e_f_f, e_f_s), e_s)) { rehash(); } + Key() { rehash(); } + //Key(const std::pair, int>& e) : execution_info(e) { rehash(); } + Key(int e_f_f, int e_f_s, int e_s) : I(e_f_f), J(e_f_s), K(e_s) { rehash(); } madness::hashT hash() const { return hash_val; } void rehash() { std::hash int_hasher; - hash_val = int_hasher(execution_info.first.first) * 2654435769 + int_hasher(execution_info.first.second) * 40503 + - int_hasher(execution_info.second); + hash_val = int_hasher(I) * 2654435769 + int_hasher(J) * 40503 + + int_hasher(K); } #ifdef TTG_SERIALIZATION_SUPPORTS_MADNESS @@ -72,14 +81,16 @@ struct Key { #ifdef TTG_SERIALIZATION_SUPPORTS_BOOST template void serialize(Archive& ar, const unsigned int) { - ar& execution_info; + ar& I; + ar& J; + ar& K; if constexpr (ttg::detail::is_boost_input_archive_v) rehash(); } #endif friend std::ostream& operator<<(std::ostream& out, Key const& k) { - out << "Key((" << k.execution_info.first.first << "," << k.execution_info.first.second << ")," - << k.execution_info.second << ")"; + out << "Key(" << k.I << "," << k.J << "," + << k.K << ")"; return out; } }; @@ -126,13 +137,13 @@ class Initiator : public TT(Key(std::make_pair(std::make_pair(i, j), 0)), kv.second, out); + ::send<0>(Key(i, j, 0), kv.second, out); } else if (i == 0) { // B function call - ::send<1>(Key(std::make_pair(std::make_pair(i, j), 0)), kv.second, out); + ::send<1>(Key(i, j, 0), kv.second, out); } else if (j == 0) { // C function call - ::send<2>(Key(std::make_pair(std::make_pair(i, j), 0)), kv.second, out); + ::send<2>(Key(i, j, 0), kv.second, out); } else { // D function call - ::send<3>(Key(std::make_pair(std::make_pair(i, j), 0)), kv.second, out); + ::send<3>(Key(i, j, 0), kv.second, out); } }); } @@ -177,8 +188,8 @@ class Finalizer : public TT, Finalizer, ttg::typelist &&in -- use direct arguments void op(const Key& key, const std::tuple>&& t, typename baseT::output_terminals_type& out) { - int I = key.execution_info.first.first; - int J = key.execution_info.first.second; + int I = key.I; + int J = key.J; int block_size = problem_size / blocking_factor; BlockMatrix bm = get<0>(t); @@ -253,9 +264,9 @@ class FuncA : public TT>& t, typename baseT::output_terminals_type& out) { - int I = key.execution_info.first.first; - int J = key.execution_info.first.second; - int K = key.execution_info.second; + int I = key.I; + int J = key.J; + int K = key.K; BlockMatrix m_ij; // Executing the update @@ -264,53 +275,35 @@ class FuncA : public TT(t)), (get<0>(t)), (get<0>(t))); // cout << "A[" << I << "," << J << "," << K << "]: " << m_ij << endl; } - // Making u_ready/v_ready for all the B/C function calls in the CURRENT iteration - std::tuple, std::vector, std::vector> bcast_keys; - for (int l = 0; l < blocking_factor; ++l) { - if (l != K) { - /*if (K == 0) { - // B calls - x_ready - ::send<1>(Key(std::make_pair(std::make_pair(I, l), K)), (*adjacency_matrix_ttg)(I,l), out); - // C calls - x_ready - ::send<2>(Key(std::make_pair(std::make_pair(l, J), K)), (*adjacency_matrix_ttg)(l,J), out); - }*/ - // B calls - // cout << "Send " << I << " " << l << " " << K << endl; - //::send<4>(Key(std::make_pair(std::make_pair(I, l), K)), m_ij, out); - std::get<1>(bcast_keys).emplace_back(I, l, K); - - // C calls - // cout << "Send " << l << " " << J << " " << K << endl; - //::send<5>(Key(std::make_pair(std::make_pair(l, J), K)), m_ij, out); - std::get<2>(bcast_keys).emplace_back(l, J, K); - } - } + /* create 2 ranges: 1) (I, l, K) with l = 0..K-1; 2) (I, l, K) with l = K+1..blocking_factor */ + auto range1_low = ttg::make_keyrange(Key(I, 0, K), Key(I, K, K), Key(0, 1, 0)); + auto range1_up = ttg::make_keyrange(Key(I, K+1, K), Key(I, blocking_factor, K), Key(0, 1, 0)); + + /* create 2 ranges: 1) (l, J, K) with l = 0..K-1; 2) (l, J, K) with l = K+1..blocking_factor */ + auto range2_low = ttg::make_keyrange(Key(0, J, K), Key(K, J, K), Key(1, 0, 0)); + auto range2_up = ttg::make_keyrange(Key(K+1, J, K), Key(blocking_factor, J, K), Key(1, 0, 0)); // making x_ready for the computation on the SAME block in the NEXT iteration if (K < (blocking_factor - 1)) { // if there is a NEXT iteration - std::get<0>(bcast_keys)[0] = {I, J, K+1}; + Key key0 = {I, J, K+1}; + auto bcast_keys = std::make_tuple(key0, + std::array{{range1_low, range1_up}}, + std::array{{range2_low, range2_up}}); if (I == K + 1 && J == K + 1) { // in the next iteration, we have A function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<0>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<0, 4, 5>(bcast_keys, std::move(m_ij), out); } else if (I == K + 1) { // in the next iteration, we have B function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<1>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<1, 4, 5>(bcast_keys, std::move(m_ij), out); } else if (J == K + 1) { // in the next iteration, we have C function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<2>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<2, 4, 5>(bcast_keys, std::move(m_ij), out); } else { // in the next iteration, we have D function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<3>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<3, 4, 5>(bcast_keys, std::move(m_ij), out); } } else { - std::get<0>(bcast_keys)[0] = {I, J, K}; - // cout << "A[" << I << "," << J << "," << K << "]: " << m_ij << endl; - //::send<6>(Key(std::make_pair(std::make_pair(I, J), K)), m_ij, out); + Key key0 = {I, J, K}; + auto bcast_keys = std::make_tuple(key0, + std::array{{range1_low, range1_up}}, + std::array{{range2_low, range2_up}}); ::broadcast<6, 4, 5>(bcast_keys, std::move(m_ij), out); } } @@ -353,9 +346,9 @@ class FuncB : public TT, const BlockMatrix>& t, typename baseT::output_terminals_type& out) { - int I = key.execution_info.first.first; - int J = key.execution_info.first.second; - int K = key.execution_info.second; + int I = key.I; + int J = key.J; + int K = key.K; BlockMatrix m_ij; // Executing the update @@ -366,20 +359,14 @@ class FuncB : public TT, std::vector> bcast_keys; - for (int i = 0; i < blocking_factor; ++i) { - if (i != I) { - // if (K == 0) - //::send<3>(Key(std::make_pair(std::make_pair(i, J), K)), (*adjacency_matrix_ttg)(i,J), out); - // cout << "Send " << i << " " << J << " " << K << endl; - //::send<4>(Key(std::make_pair(std::make_pair(i, J), K)), m_ij, out); - std::get<1>(bcast_keys).emplace_back(i, J, K); - } - } + /* create 2 ranges: 1) (I, l, K) with l = 0..K-1; 2) (I, l, K) with l = K+1..blocking_factor */ + auto range1_low = ttg::make_keyrange(Key(0, J, K), Key(I, J, K), Key(1, 0, 0)); + auto range1_up = ttg::make_keyrange(Key(I+1, J, K), Key(blocking_factor, J, K), Key(1, 0, 0)); // making x_ready for the computation on the SAME block in the NEXT iteration if (K < (blocking_factor - 1)) { // if there is a NEXT iteration - std::get<0>(bcast_keys)[0] = {I, J, K+1}; + Key key0 = {I, J, K+1}; + auto bcast_keys = std::make_tuple(key0, std::array{{range1_low, range1_up}}); if (I == K + 1 && J == K + 1) { // in the next iteration, we have A function call // cout << "Send " << I << " " << J << " " << K << endl; //::send<0>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); @@ -398,7 +385,8 @@ class FuncB : public TT(bcast_keys, std::move(m_ij), out); } } else { - std::get<0>(bcast_keys)[0] = {I, J, K}; + Key key0 = {I, J, K}; + auto bcast_keys = std::make_tuple(key0, std::array{{range1_low, range1_up}}); // cout << "B[" << I << "," << J << "," << K << "]: " << m_ij << endl; //::send<5>(Key(std::make_pair(std::make_pair(I, J), K)), m_ij, out); ::broadcast<5, 4>(bcast_keys, std::move(m_ij), out); @@ -443,9 +431,9 @@ class FuncC : public TT, BlockMatrix>& t, typename baseT::output_terminals_type& out) { - int I = key.execution_info.first.first; - int J = key.execution_info.first.second; - int K = key.execution_info.second; + int I = key.I; + int J = key.J; + int K = key.K; BlockMatrix m_ij; // Executing the update @@ -456,40 +444,26 @@ class FuncC : public TT, std::vector> bcast_keys; - for (int j = 0; j < blocking_factor; ++j) { - if (j != J) { - //::send<4>(Key(std::make_pair(std::make_pair(I, j), K)), (*adjacency_matrix_ttg)(I,j), out); - // cout << "Send " << I << " " << j << " " << K << endl; - //::send<4>(Key(std::make_pair(std::make_pair(I, j), K)), m_ij, out); - std::get<1>(bcast_keys).emplace_back(I, j, K); - } - } + auto range1_low = ttg::make_keyrange(Key(I, 0, K), Key(I, J, K), Key(0, 1, 0)); + auto range1_up = ttg::make_keyrange(Key(I, J+1, K), Key(I, blocking_factor, K), Key(0, 1, 0)); // making x_ready for the computation on the SAME block in the NEXT iteration if (K < (blocking_factor - 1)) { // if there is a NEXT iteration - std::get<0>(bcast_keys)[0] = {I, J, K+1}; + Key key0 = {I, J, K+1}; + auto bcast_keys = std::make_tuple(key0, std::array{{range1_low, range1_up}}); if (I == K + 1 && J == K + 1) { // in the next iteration, we have A function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<0>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<0, 4>(bcast_keys, std::move(m_ij), out); } else if (I == K + 1) { // in the next iteration, we have B function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<1>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<1, 4>(bcast_keys, std::move(m_ij), out); } else if (J == K + 1) { // in the next iteration, we have C function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<2>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<2, 4>(bcast_keys, std::move(m_ij), out); } else { // in the next iteration, we have D function call - // cout << "Send " << I << " " << J << " " << K << endl; - //::send<3>(Key(std::make_pair(std::make_pair(I, J), K + 1)), m_ij, out); ::broadcast<3, 4>(bcast_keys, std::move(m_ij), out); } } else { // cout << "C[" << I << "," << J << "," << K << "]: " << m_ij << endl; - //::send<5>(Key(std::make_pair(std::make_pair(I, J), K)), m_ij, out); - std::get<0>(bcast_keys)[0] = {I, J, K}; + Key key0 = {I, J, K}; + auto bcast_keys = std::make_tuple(key0, std::array{{range1_low, range1_up}}); ::broadcast<5, 4>(bcast_keys, std::move(m_ij), out); } } @@ -532,9 +506,9 @@ class FuncD : public TT, const BlockMatrix, const BlockMatrix>& t, typename baseT::output_terminals_type& out) { - int I = key.execution_info.first.first; - int J = key.execution_info.first.second; - int K = key.execution_info.second; + int I = key.I; + int J = key.J; + int K = key.K; BlockMatrix m_ij; // Executing the update @@ -548,20 +522,20 @@ class FuncD : public TT(Key(std::make_pair(std::make_pair(I, J), K + 1)), std::move(m_ij), out); + ::send<0>(Key(I, J, K + 1), std::move(m_ij), out); } else if (I == K + 1) { // in the next iteration, we have B function call // cout << "Send " << I << " " << J << " " << K << endl; - ::send<1>(Key(std::make_pair(std::make_pair(I, J), K + 1)), std::move(m_ij), out); + ::send<1>(Key(I, J, K + 1), std::move(m_ij), out); } else if (J == K + 1) { // in the next iteration, we have C function call // cout << "Send " << I << " " << J << " " << K << endl; - ::send<2>(Key(std::make_pair(std::make_pair(I, J), K + 1)), std::move(m_ij), out); + ::send<2>(Key(I, J, K + 1), std::move(m_ij), out); } else { // in the next iteration, we have D function call // cout << "Send " << I << " " << J << " " << K << endl; - ::send<3>(Key(std::make_pair(std::make_pair(I, J), K + 1)), std::move(m_ij), out); + ::send<3>(Key(I, J, K + 1), std::move(m_ij), out); } } else { // cout << "D[" << I << "," << J << "," << K << "]: " << m_ij << endl; - ::send<4>(Key(std::make_pair(std::make_pair(I, J), K)), std::move(m_ij), out); + ::send<4>(Key(I, J, K), std::move(m_ij), out); } } }; @@ -724,8 +698,8 @@ int main(int argc, char** argv) { int P = std::sqrt(world.size()); int Q = world.size() / P; auto keymap = [=](const Key &key) { - int I = key.execution_info.first.first; - int J = key.execution_info.first.second; + int I = key.I; + int J = key.J; return ((I%P) + (J%Q)*P); }; diff --git a/ttg/CMakeLists.txt b/ttg/CMakeLists.txt index 49b4e07c63..9575a13a09 100644 --- a/ttg/CMakeLists.txt +++ b/ttg/CMakeLists.txt @@ -36,6 +36,7 @@ set(ttg-impl-headers ${CMAKE_CURRENT_SOURCE_DIR}/ttg/func.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/fwd.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/impl_selector.h + ${CMAKE_CURRENT_SOURCE_DIR}/ttg/keyrange.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/tt.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/reduce.h ${CMAKE_CURRENT_SOURCE_DIR}/ttg/run.h diff --git a/ttg/ttg.h b/ttg/ttg.h index e0fa9a7029..e840f2a475 100644 --- a/ttg/ttg.h +++ b/ttg/ttg.h @@ -17,6 +17,7 @@ #include "ttg/base/world.h" #include "ttg/broadcast.h" #include "ttg/func.h" +#include "ttg/keyrange.h" #include "ttg/reduce.h" #include "ttg/traverse.h" #include "ttg/tt.h" diff --git a/ttg/ttg/func.h b/ttg/ttg/func.h index 9849bfc528..4802ded503 100644 --- a/ttg/ttg/func.h +++ b/ttg/ttg/func.h @@ -238,7 +238,7 @@ namespace ttg { inline void broadcast(const std::tuple &keylists, valueT &&value, std::tuple...> &t) { if constexpr (ttg::meta::is_iterable_v>>) { - if (std::distance(std::begin(std::get(keylists)), std::end(std::get(keylists))) > 0) { + if (std::begin(std::get(keylists)) != std::end(std::get(keylists))) { std::get(t).broadcast(std::get(keylists), value); } } else { @@ -252,7 +252,7 @@ namespace ttg { template inline void broadcast(const std::tuple &keylists, valueT &&value) { if constexpr (ttg::meta::is_iterable_v>>) { - if (std::distance(std::begin(std::get(keylists)), std::end(std::get(keylists))) > 0) { + if (std::begin(std::get(keylists)) != std::end(std::get(keylists))) { using key_t = decltype(*std::begin(std::get(keylists))); auto *terminal_ptr = detail::get_out_terminal(i, "ttg::broadcast(keylists, value)"); terminal_ptr->broadcast(std::get(keylists), value); @@ -270,7 +270,7 @@ namespace ttg { template inline void broadcast(const std::tuple &keylists, std::tuple...> &t) { if constexpr (ttg::meta::is_iterable_v>>) { - if (std::distance(std::begin(std::get(keylists)), std::end(std::get(keylists))) > 0) { + if (std::begin(std::get(keylists)) != std::end(std::get(keylists))) { std::get(t).broadcast(std::get(keylists)); } } else { @@ -284,7 +284,7 @@ namespace ttg { template inline void broadcast(const std::tuple &keylists) { if constexpr (ttg::meta::is_iterable_v>>) { - if (std::distance(std::begin(std::get(keylists)), std::end(std::get(keylists))) > 0) { + if (std::begin(std::get(keylists)) != std::end(std::get(keylists))) { using key_t = decltype(*std::begin(std::get(keylists))); auto *terminal_ptr = detail::get_out_terminal(i, "ttg::broadcast(keylists)"); terminal_ptr->broadcast(std::get(keylists)); diff --git a/ttg/ttg/keyrange.h b/ttg/ttg/keyrange.h new file mode 100644 index 0000000000..602d4b1dbf --- /dev/null +++ b/ttg/ttg/keyrange.h @@ -0,0 +1,209 @@ +#ifndef TTG_KEYTERATOR_H +#define TTG_KEYTERATOR_H + +#include +#include +#include + +#include "ttg/util/meta.h" +#include "ttg/serialization.h" + +namespace ttg { + + /* Trait providing the diff type of a given key + * Defaults to the key itself. + * May be provided as \c diff_type member of the key or by + * overloading this trait. + */ + template + struct key_diff_type { + using type = Key; + }; + + /* Overload for Key::diff_type */ + template + struct key_diff_type>{ + using type = typename Key::diff_type; + }; + + /* Convenience type */ + template + using key_diff_type_t = typename key_diff_type::type; + + namespace detail { + /** + * Trait checking whether a key is compatible with the LinearKeyRange. + * Keys must at least support comparison as well as addition or compound addition. + */ + template + struct is_range_compatible { + using key_type = std::decay_t; + using difference_type = key_diff_type_t; + constexpr static bool value = ttg::meta::is_comparable_v + && (ttg::meta::has_addition_v + || ttg::meta::has_compound_addition_v) + && (std::is_trivially_copyable_v || is_user_buffer_serializable_v); + }; + + template + constexpr bool is_range_compatible_v = is_range_compatible::value; + + /** + * Represents a range of keys that can be represented as a linear iteration + * space, i.e., using a start and end (one past the last element) as well as + * a step increment. An iterator is provided to iterate over the range of keys. + * + * The step increment is of type \sa ttg::key_diff_type, which defaults + * to the key type but can be overridden by either defining a \c diff_type + * member type or by specializing ttg::key_diff_type. + */ + template + struct LinearKeyRange { + using key_type = std::decay_t; + using diff_type = key_diff_type_t; + + /* Forward Iterator for the linear key range */ + struct iterator + { + /* types for std::iterator_trait */ + using value_type = key_type; + using difference_type = diff_type; + using pointer = const value_type*; + using reference = const value_type&; + using iterator_category = std::forward_iterator_tag; + + iterator(const key_type& pos, const diff_type& inc) + : m_pos(pos), m_inc(inc) + { } + + iterator& operator++() { + if constexpr (meta::has_compound_addition_v) { + m_pos += m_inc; + } else if constexpr (meta::has_addition_v) { + m_pos = m_pos + m_inc; + } else { + throw std::logic_error("Key type does not support addition its with difference type"); + } + return *this; + } + + iterator operator++(int) { + iterator retval = *this; + ++(*this); + return retval; + } + + bool operator==(const iterator& other) const { + if constexpr(meta::is_comparable_v) { + return m_pos == other.m_pos; + } + return false; + } + + bool operator!=(const iterator& other) const { + return !(*this == other); + } + const key_type& operator*() const { + return m_pos; + } + const key_type* operator->() const { + return &m_pos; + } + + private: + key_type m_pos; + const diff_type& m_inc; + }; + + LinearKeyRange() + { } + + LinearKeyRange(const Key& begin, const Key& end, const diff_type& inc) + : m_begin(begin) + , m_end(end) + , m_inc(inc) + { } + + iterator begin() const { + return iterator(m_begin, m_inc); + } + + iterator end() const { + return iterator(m_end, m_inc); + } + +#ifdef TTG_SERIALIZATION_SUPPORTS_MADNESS + /* Make the LinearKeyRange madness serializable */ + template + && is_madness_buffer_serializable_v>> + void serialize(Archive& ar) { + ar & m_begin; + ar & m_end; + ar & m_inc; + } +#endif + +#ifdef TTG_SERIALIZATION_SUPPORTS_BOOST + /* Make the LinearKeyRange boost serializable */ + template + && is_boost_buffer_serializable_v>> + void serialize(Archive& ar, const unsigned int version) { + ar & m_begin; + ar & m_end; + ar & m_inc; + } +#endif + + friend std::ostream& operator<<(std::ostream& out, LinearKeyRange const& k) { + out << "LinearKeyRange[" << k.m_begin << "," << k.m_end << "):"< + inline auto make_keyrange(const Key& begin, + const Key& end, + const key_diff_type_t& inc) { + static_assert(detail::is_range_compatible_v, + "Key type does not support all required operations: operator==, " + "operator+ or operator+=, and serialization (trivially_copyable, madness, or boost)"); + return detail::LinearKeyRange(begin, end, inc); + } + + /** + * Create a key range [begin, end) with unit stride. + * + * Requires \c operator++ and \c operator- to be defined on Key. + * If a difference type (\c Key::diff_type) is defined, \c operator- should return + * the difference type. + * + * \return A representation of the range that can be passed to send/broadcast. + */ + template + inline auto make_keyrange(const Key& begin, const Key& end) { + static_assert(detail::is_range_compatible_v, + "Key type does not support all required operations: operator==, " + "operator+ or operator+=, and serialization (trivially_copyable, madness, or boost)"); + static_assert(meta::has_increment_v && meta::has_difference_v, + "Unit stride key range requires operator++ and operator- on Key"); + if constexpr (meta::has_pre_increment_v) { + return detail::LinearKeyRange(begin, end, ++begin - begin); + } else { + return detail::LinearKeyRange(begin, end, begin++ - begin); + } + } + +} // namespace ttg +#endif // TTG_KEYTERATOR_H diff --git a/ttg/ttg/parsec/ttg.h b/ttg/ttg/parsec/ttg.h index 5ff550a64c..0c70648f82 100644 --- a/ttg/ttg/parsec/ttg.h +++ b/ttg/ttg/parsec/ttg.h @@ -20,6 +20,7 @@ #include "ttg/runtimes.h" #include "ttg/terminal.h" #include "ttg/tt.h" +#include "ttg/util/bits.h" #include "ttg/util/env.h" #include "ttg/util/hash.h" #include "ttg/util/meta.h" @@ -79,8 +80,8 @@ namespace ttg_parsec { struct msg_header_t { typedef enum { MSG_SET_ARG = 0, MSG_SET_ARGSTREAM_SIZE = 1, MSG_FINALIZE_ARGSTREAM_SIZE = 2 } fn_id_t; + uint64_t tt_id; uint32_t taskpool_id; - uint64_t op_id; fn_id_t fn_id; int32_t param_id; int num_keys; @@ -93,12 +94,12 @@ namespace ttg_parsec { static_set_arg_fct_type static_set_arg_fct; parsec_taskpool_t *tp = NULL; msg_header_t *msg = static_cast(data); - uint64_t op_id = msg->op_id; + uint64_t tt_id = msg->tt_id; tp = parsec_taskpool_lookup(msg->taskpool_id); assert(NULL != tp); static_map_mutex.lock(); try { - auto op_pair = static_id_to_op_map.at(op_id); + auto op_pair = static_id_to_op_map.at(tt_id); static_map_mutex.unlock(); tp->tdm.module->incoming_message_start(tp, src_rank, NULL, NULL, 0, NULL); static_set_arg_fct = op_pair.first; @@ -110,8 +111,8 @@ namespace ttg_parsec { assert(data_cpy != 0); memcpy(data_cpy, data, size); ttg::trace("ttg_parsec(", ttg_default_execution_context().rank(), ") Delaying delivery of message (", src_rank, - ", ", op_id, ", ", data_cpy, ", ", size, ")"); - delayed_unpack_actions.insert(std::make_pair(op_id, std::make_tuple(src_rank, data_cpy, size))); + ", ", tt_id, ", ", data_cpy, ", ", size, ")"); + delayed_unpack_actions.insert(std::make_pair(tt_id, std::make_tuple(src_rank, data_cpy, size))); static_map_mutex.unlock(); return 1; } @@ -500,14 +501,15 @@ namespace ttg_parsec { .key_hash = parsec_hash_table_generic_64bits_key_hash}; template - class rma_delayed_activate { + class rma_delayed_activate_keylist { std::vector _keylist; std::atomic _outstanding_transfers; ActivationCallbackT _cb; ttg_data_copy_t *_copy; public: - rma_delayed_activate(std::vector &&key, ttg_data_copy_t *copy, int num_transfers, ActivationCallbackT cb) + rma_delayed_activate_keylist(std::vector &&key, ttg_data_copy_t *copy, + int num_transfers, ActivationCallbackT cb) : _keylist(std::move(key)), _outstanding_transfers(num_transfers), _cb(cb), _copy(copy) {} bool complete_transfer(void) { @@ -520,6 +522,51 @@ namespace ttg_parsec { } }; + template + class rma_delayed_activate_keyranges { + std::vector> _keyranges; + std::atomic _outstanding_transfers; + ActivationCallbackT _cb; + ttg_data_copy_t *_copy; + + public: + rma_delayed_activate_keyranges(std::vector> &&ranges, ttg_data_copy_t *copy, + int num_transfers, ActivationCallbackT cb) + : _keyranges(std::move(ranges)), _outstanding_transfers(num_transfers), _cb(cb), _copy(copy) {} + + bool complete_transfer(void) { + int left = --_outstanding_transfers; + if (0 == left) { + _cb(_keyranges, _copy); + return true; + } + return false; + } + }; + + template + class rma_delayed_activate_keyrange { + ttg::detail::LinearKeyRange _keyrange; + std::atomic _outstanding_transfers; + ActivationCallbackT _cb; + ttg_data_copy_t *_copy; + + public: + rma_delayed_activate_keyrange(ttg::detail::LinearKeyRange &range, ttg_data_copy_t *copy, + int num_transfers, ActivationCallbackT cb) + : _keyrange(range), _outstanding_transfers(num_transfers), _cb(cb), _copy(copy) {} + + bool complete_transfer(void) { + int left = --_outstanding_transfers; + if (0 == left) { + _cb(_keyrange, _copy); + return true; + } + return false; + } + }; + + template static int get_complete_cb(parsec_comm_engine_t *comm_engine, parsec_ce_mem_reg_handle_t lreg, ptrdiff_t ldispl, parsec_ce_mem_reg_handle_t rreg, ptrdiff_t rdispl, size_t size, int remote, @@ -732,7 +779,7 @@ namespace ttg_parsec { msg_t() = default; msg_t(uint64_t tt_id, uint32_t taskpool_id, msg_header_t::fn_id_t fn_id, int32_t param_id, int num_keys = 1) - : tt_id{taskpool_id, tt_id, fn_id, param_id, num_keys} {} + : tt_id{tt_id, taskpool_id, fn_id, param_id, num_keys} {} }; } // namespace detail @@ -997,8 +1044,8 @@ namespace ttg_parsec { return &mempools.thread_mempools[index]; } - template - void set_arg_from_msg_keylist(ttg::span &&keylist, ttg_data_copy_t *copy) { + template + void set_arg_from_msg_keylist(Iterator&& begin, Iterator&& end, ttg_data_copy_t *copy) { /* create a dummy task that holds the copy, which can be reused by others */ task_t *dummy; parsec_execution_stream_s *es = world.impl().execution_stream(); @@ -1015,8 +1062,11 @@ namespace ttg_parsec { /* iterate over the keys and have them use the copy we made */ parsec_task_t *task_ring = nullptr; - for (auto &&key : keylist) { - set_arg_local_impl(key, *reinterpret_cast(copy->device_private), copy, &task_ring); + for (auto it = begin; it != end; ++it) { + if constexpr (HasNonLocalKeys) { + if (keymap(*it) != world.rank()) continue; + } + set_arg_local_impl(*it, *reinterpret_cast(copy->device_private), copy, &task_ring); } if (nullptr != task_ring) { @@ -1032,6 +1082,37 @@ namespace ttg_parsec { parsec_thread_mempool_free(mempool, &dummy->parsec_task); } + inline auto extract_keylist_from_msg(detail::msg_t *msg, uint64_t& pos) { + std::vector keylist; + int num_keys = msg->tt_id.num_keys; + keylist.reserve(num_keys); + auto rank = world.rank(); + for (int k = 0; k < num_keys; ++k) { + keyT key; + pos = unpack(key, msg->bytes, pos); + assert(keymap(key) == rank); + keylist.push_back(std::move(key)); + } + return keylist; + } + + inline auto extract_keyrange_from_msg(detail::msg_t *msg, uint64_t& pos) { + ttg::detail::LinearKeyRange range; + pos = unpack(range, msg->bytes, pos); + return range; + } + + inline auto extract_keyranges_from_msg(detail::msg_t *msg, uint64_t& pos) { + int num_ranges = -msg->tt_id.num_keys; + std::vector> ranges(num_ranges); + for (int i = 0; i < num_ranges; ++i) { + ttg::detail::LinearKeyRange range; + pos = unpack(range, msg->bytes, pos); + ranges.push_back(std::move(range)); + } + return ranges; + } + // there are 6 types of set_arg: // - case 1: nonvoid Key, complete Value type // - case 2: nonvoid Key, void Value, mixed (data+control) inputs @@ -1049,15 +1130,17 @@ namespace ttg_parsec { if constexpr (!ttg::meta::is_void_v) { /* unpack the keys */ uint64_t pos = 0; + /* we can have either a list of keys, a single range, or a list of ranges */ std::vector keylist; + std::vector> keyranges; + ttg::detail::LinearKeyRange keyrange; int num_keys = msg->tt_id.num_keys; - keylist.reserve(num_keys); - auto rank = world.rank(); - for (int k = 0; k < num_keys; ++k) { - keyT key; - pos = unpack(key, msg->bytes, pos); - assert(keymap(key) == rank); - keylist.push_back(std::move(key)); + if (num_keys >= 0) { + keylist = extract_keylist_from_msg(msg, pos); + } else if (num_keys == -1) { + keyrange = extract_keyrange_from_msg(msg, pos); + } else { + keyranges = extract_keyranges_from_msg(msg, pos); } // case 1 if constexpr (!ttg::meta::is_void_v) { @@ -1065,9 +1148,17 @@ namespace ttg_parsec { if constexpr (!ttg::has_split_metadata::value) { ttg_data_copy_t *copy = detail::create_new_datacopy(decvalueT{}); unpack(*static_cast(copy->device_private), msg->bytes, pos); - - set_arg_from_msg_keylist(ttg::span(&keylist[0], num_keys), copy); + if (num_keys >= 0) { + set_arg_from_msg_keylist(keylist.begin(), keylist.end(), copy); + } else if (num_keys == -1) { + set_arg_from_msg_keylist(keyrange.begin(), keyrange.end(), copy); + } else { + for (const auto& range : keyranges) { + set_arg_from_msg_keylist(range.begin(), range.end(), copy); + } + } } else { + /* unpack the header and start the RMA transfers */ ttg::SplitMetadataDescriptor descr; using metadata_t = decltype(descr.get_metadata(std::declval())); @@ -1093,23 +1184,61 @@ namespace ttg_parsec { ttg_data_copy_t *copy = detail::create_new_datacopy(descr.create_from_metadata(metadata)); /* nothing else to do if the object is empty */ if (0 == num_iovecs) { - set_arg_from_msg_keylist(keylist, copy); + if (num_keys >= 0) { + set_arg_from_msg_keylist(keylist.begin(), keylist.end(), copy); + } else if (num_keys == -1) { + set_arg_from_msg_keylist(keyrange.begin(), keyrange.end(), copy); + } else { + for (const auto& range : keyranges) { + set_arg_from_msg_keylist(range.begin(), range.end(), copy); + } + } } else { /* extract the callback tag */ parsec_ce_tag_t cbtag; std::memcpy(&cbtag, msg->bytes + pos, sizeof(cbtag)); pos += sizeof(cbtag); + /* pick the right activation callback */ + void* activation_ptr; + parsec_ce_onesided_callback_t cb; + if (num_keys >= 0) { + auto activation = new detail::rma_delayed_activate_keylist( + std::move(keylist), copy, num_iovecs, + [this](std::vector &&keylist, ttg_data_copy_t *copy) { + set_arg_from_msg_keylist(keylist.begin(), keylist.end(), copy); + this->world.impl().decrement_inflight_msg(); + }); + activation_ptr = activation; + using ActivationT = std::decay_t; + cb = &detail::get_complete_cb; + } else if (num_keys == -1) { + auto activation = new detail::rma_delayed_activate_keyrange( + keyrange, copy, num_iovecs, + [this](ttg::detail::LinearKeyRange &keyrange, ttg_data_copy_t *copy) { + set_arg_from_msg_keylist(keyrange.begin(), keyrange.end(), copy); + this->world.impl().decrement_inflight_msg(); + }); + activation_ptr = activation; + using ActivationT = std::decay_t; + cb = &detail::get_complete_cb; + } else { + auto activation = new detail::rma_delayed_activate_keyranges( + std::move(keyranges), copy, num_iovecs, + [this](std::vector> &keyranges, ttg_data_copy_t *copy) { + for (const auto& range : keyranges) { + set_arg_from_msg_keylist(range.begin(), range.end(), copy); + } + this->world.impl().decrement_inflight_msg(); + }); + activation_ptr = activation; + using ActivationT = std::decay_t; + cb = &detail::get_complete_cb; + } + /* create the value from the metadata */ - auto activation = new detail::rma_delayed_activate( - std::move(keylist), copy, num_iovecs, [this](std::vector &&keylist, ttg_data_copy_t *copy) { - set_arg_from_msg_keylist(keylist, copy); - this->world.impl().decrement_inflight_msg(); - }); auto &val = *static_cast(copy->device_private); - using ActivationT = std::decay_t; - int nv = 0; /* process payload iovecs */ auto iovecs = descr.get_data(val); @@ -1136,7 +1265,7 @@ namespace ttg_parsec { world.impl().increment_inflight_msg(); /* TODO: PaRSEC should treat the remote callback as a tag, not a function pointer! */ parsec_ce.get(&parsec_ce, lreg, 0, rreg, 0, iov.num_bytes, remote, - &detail::get_complete_cb, activation, + cb, activation_ptr, /*world.impl().parsec_ttg_rma_tag()*/ cbtag, &fn_ptr, sizeof(std::intptr_t)); } @@ -1147,8 +1276,20 @@ namespace ttg_parsec { } // case 2 and 3 } else if constexpr (!ttg::meta::is_void_v && std::is_void_v) { - for (auto &&key : keylist) { - set_arg(key, ttg::Void{}); + if (num_keys >= 0) { + for (auto &&key : keylist) { + set_arg(key, ttg::Void{}); + } + } else if (num_keys == -1) { + for (const auto &key : keyrange) { + set_arg(key, ttg::Void{}); + } + } else { + for (const auto& range : keyranges) { + for (const auto &key : range) { + set_arg(key, ttg::Void{}); + } + } } } // case 4 @@ -1571,12 +1712,11 @@ namespace ttg_parsec { parsec_taskpool_t *tp = world_impl.taskpool(); tp->tdm.module->outgoing_message_start(tp, owner, NULL); tp->tdm.module->outgoing_message_pack(tp, owner, NULL, NULL, 0); - // std::cout << "Sending AM with " << msg->op_id.num_keys << " keys " << std::endl; parsec_ce.send_am(&parsec_ce, world_impl.parsec_ttg_tag(), owner, static_cast(msg.get()), sizeof(msg_header_t) + pos); } - template + template void broadcast_arg_local(Iterator &&begin, Iterator &&end, const Value &value) { parsec_task_t *task_ring = nullptr; ttg_data_copy_t *copy = nullptr; @@ -1585,6 +1725,9 @@ namespace ttg_parsec { } for (auto it = begin; it != end; ++it) { + if constexpr (HasNonLocalKeys) { + if (keymap(*it) != world.rank()) continue; + } set_arg_local_impl(*it, value, copy, &task_ring); } /* submit all ready tasks at once */ @@ -1657,10 +1800,133 @@ namespace ttg_parsec { sizeof(msg_header_t) + pos); } /* handle local keys */ - broadcast_arg_local(local_begin, local_end, value); + broadcast_arg_local(local_begin, local_end, value); } else { /* only local keys */ - broadcast_arg_local(keylist.begin(), keylist.end(), value); + broadcast_arg_local(keylist.begin(), keylist.end(), value); + } + } + + + template + std::enable_if_t && !std::is_void_v> && + !ttg::has_split_metadata>::value, + void> + rangecast_arg(const ttg::span> &ranges, const Value &value) { + auto world = ttg_default_execution_context(); + int rank = world.rank(); + + auto rankset = ttg::bitset(world.size()); + /* flag every rank for which we have a key */ + for (const auto& range : ranges) { + for (const auto& key : range) { + rankset.set(keymap(key)); + } + } + bool have_remote = rankset.popcnt() > 1 || (rankset.popcnt() == 1 && !rankset.get(rank)); + + if (have_remote) { + + size_t pos = 0; + using msg_t = detail::msg_t; + auto &world_impl = world.impl(); + std::unique_ptr msg = std::make_unique(get_instance_id(), world_impl.taskpool()->taskpool_id, + msg_header_t::MSG_SET_ARG, i); + + parsec_taskpool_t *tp = world_impl.taskpool(); + + for (int owner = 0; owner < world.size(); ++owner) { + if (rank == owner || !rankset[owner]) continue; + + /* pack the key range */ + int num_ranges = 0; + for (const auto& range : ranges) { + pos = pack(range, msg->bytes, pos); + ++num_ranges; + } + msg->tt_id.num_keys = -num_ranges; // mark as ranges + + /* TODO: use RMA to transfer large values */ + pos = pack(value, msg->bytes, pos); + + /* Send the message */ + tp->tdm.module->outgoing_message_start(tp, owner, NULL); + tp->tdm.module->outgoing_message_pack(tp, owner, NULL, NULL, 0); + parsec_ce.send_am(&parsec_ce, world_impl.parsec_ttg_tag(), owner, static_cast(msg.get()), + sizeof(msg_header_t) + pos); + } + /* handle local keys */ + if (rankset[rank]) { + for (const auto& range : ranges) { + broadcast_arg_local(range.begin(), range.end(), value); + } + } + } else { + /* only local keys */ + for (const auto& range : ranges) { + broadcast_arg_local(range.begin(), range.end(), value); + } + } + } + + + template + inline + auto splitmd_get_iovs(const Value &value, + ttg::SplitMetadataDescriptor &descr) { + auto iovs = descr.get_data(*const_cast(&value)); + int32_t num_iovs = std::distance(std::begin(iovs), std::end(iovs)); + std::vector>> memregs; + memregs.reserve(num_iovs); + + /* register all iovs so the registration can be reused */ + for (auto &&iov : iovs) { + parsec_ce_mem_reg_handle_t lreg; + size_t lreg_size; + parsec_ce.mem_register(iov.data, PARSEC_MEM_TYPE_NONCONTIGUOUS, iov.num_bytes, parsec_datatype_int8_t, + iov.num_bytes, &lreg, &lreg_size); + /* TODO: use a static function for deregistration here? */ + memregs.push_back(std::make_pair(static_cast(lreg_size), + /* TODO: this assumes that parsec_ce_mem_reg_handle_t is void* */ + std::shared_ptr{lreg, [](void *ptr) { + parsec_ce_mem_reg_handle_t memreg = + (parsec_ce_mem_reg_handle_t)ptr; + parsec_ce.mem_unregister(&memreg); + }})); + } + return memregs; + } + + /** + * pack the registration handles + * memory layout: [, ...] + */ + template + inline + void splitmd_pack_reghandles(std::vector>> memregs, + detail::msg_t* msg, size_t& pos) { + + for (auto &&memreg : memregs) { + int32_t lreg_size; + std::shared_ptr lreg_ptr; + std::tie(lreg_size, lreg_ptr) = memreg; + std::memcpy(msg->bytes + pos, &lreg_size, sizeof(lreg_size)); + pos += sizeof(lreg_size); + std::memcpy(msg->bytes + pos, lreg_ptr.get(), lreg_size); + pos += lreg_size; + /* create a function that will be invoked upon RMA completion at the target */ + std::shared_ptr lreg_ptr_v = lreg_ptr; + /* mark another reader on the copy */ + ttg_data_copy_t *copy = detail::register_data_copy(copy, nullptr, true); + std::function *fn = new std::function([=]() mutable { + /* shared_ptr of value and registration captured by value so resetting + * them here will eventually release the memory/registration */ + detail::release_data_copy(copy); + lreg_ptr_v.reset(); + }); + std::intptr_t fn_ptr{reinterpret_cast(fn)}; + std::memcpy(msg->bytes + pos, &fn_ptr, sizeof(fn_ptr)); + pos += sizeof(fn_ptr); } } @@ -1691,26 +1957,8 @@ namespace ttg_parsec { auto local_end = keylist_sorted.end(); ttg::SplitMetadataDescriptor descr; - auto iovs = descr.get_data(*const_cast(&value)); - int32_t num_iovs = std::distance(std::begin(iovs), std::end(iovs)); - std::vector>> memregs; - memregs.reserve(num_iovs); - - /* register all iovs so the registration can be reused */ - for (auto &&iov : iovs) { - parsec_ce_mem_reg_handle_t lreg; - size_t lreg_size; - parsec_ce.mem_register(iov.data, PARSEC_MEM_TYPE_NONCONTIGUOUS, iov.num_bytes, parsec_datatype_int8_t, - iov.num_bytes, &lreg, &lreg_size); - /* TODO: use a static function for deregistration here? */ - memregs.push_back(std::make_pair(static_cast(lreg_size), - /* TODO: this assumes that parsec_ce_mem_reg_handle_t is void* */ - std::shared_ptr{lreg, [](void *ptr) { - parsec_ce_mem_reg_handle_t memreg = - (parsec_ce_mem_reg_handle_t)ptr; - parsec_ce.mem_unregister(&memreg); - }})); - } + std::vector>> memregs = splitmd_get_iovs(value, descr); + size_t num_iovs = memregs.size(); using msg_t = detail::msg_t; auto &world_impl = world.impl(); @@ -1763,47 +2011,103 @@ namespace ttg_parsec { parsec_ce_tag_t cbtag = reinterpret_cast(&detail::get_remote_complete_cb); std::memcpy(msg->bytes + pos, &cbtag, sizeof(cbtag)); pos += sizeof(cbtag); + splitmd_pack_reghandles(memregs, msg.get(), pos); + tp->tdm.module->outgoing_message_start(tp, owner, NULL); + tp->tdm.module->outgoing_message_pack(tp, owner, NULL, NULL, 0); + parsec_ce.send_am(&parsec_ce, world_impl.parsec_ttg_tag(), owner, static_cast(msg.get()), + sizeof(msg_header_t) + pos); + } + /* handle local keys */ + broadcast_arg_local(local_begin, local_end, value); + } else { + /* handle local keys */ + broadcast_arg_local(keylist.begin(), keylist.end(), value); + } + } - /** - * pack the registration handles - * memory layout: [, ...] - */ - int idx = 0; - for (auto &&iov : iovs) { - // auto [lreg_size, lreg_ptr] = memregs[idx]; - int32_t lreg_size; - std::shared_ptr lreg_ptr; - std::tie(lreg_size, lreg_ptr) = memregs[idx]; - std::memcpy(msg->bytes + pos, &lreg_size, sizeof(lreg_size)); - pos += sizeof(lreg_size); - std::memcpy(msg->bytes + pos, lreg_ptr.get(), lreg_size); - pos += lreg_size; - /* create a function that will be invoked upon RMA completion at the target */ - std::shared_ptr lreg_ptr_v = lreg_ptr; - /* mark another reader on the copy */ - copy = detail::register_data_copy(copy, nullptr, true); - std::function *fn = new std::function([=]() mutable { - /* shared_ptr of value and registration captured by value so resetting - * them here will eventually release the memory/registration */ - detail::release_data_copy(copy); - lreg_ptr_v.reset(); - }); - std::intptr_t fn_ptr{reinterpret_cast(fn)}; - std::memcpy(msg->bytes + pos, &fn_ptr, sizeof(fn_ptr)); - pos += sizeof(fn_ptr); - ++idx; + template + std::enable_if_t && !std::is_void_v> && + ttg::has_split_metadata>::value, + void> + splitmd_rangecast_arg(const ttg::span> &ranges, const Value &value) { + using valueT = std::tuple_element_t; + auto world = ttg_default_execution_context(); + int rank = world.rank(); + auto rankset = ttg::bitset(world.size()); + /* flag every rank for which we have a key */ + for (const auto& range : ranges) { + for (const auto& key : range) { + rankset.set(keymap(key)); + } + } + bool have_remote = rankset.popcnt() > 1 || (rankset.popcnt() == 1 && !rankset[rank]); + + if (have_remote) { + using decvalueT = std::decay_t; + + ttg::SplitMetadataDescriptor descr; + std::vector>> memregs = splitmd_get_iovs(value, descr); + size_t num_iovs = memregs.size(); + + using msg_t = detail::msg_t; + auto &world_impl = world.impl(); + std::unique_ptr msg = std::make_unique(get_instance_id(), world_impl.taskpool()->taskpool_id, + msg_header_t::MSG_SET_ARG, i); + auto metadata = descr.get_metadata(value); + size_t metadata_size = sizeof(metadata); + + ttg_data_copy_t *copy; + copy = detail::find_copy_in_task(parsec_ttg_caller, &value); + assert(nullptr != copy); + + parsec_taskpool_t *tp = world_impl.taskpool(); + size_t pos = 0; + for (int owner = 0; owner < world.size(); ++owner) { + if (rank == owner || !rankset[owner]) continue; + + int num_ranges = 0; + for (const auto& range : ranges) { + /* pack the key range objects */ + pos = pack(range, msg->bytes, pos); + ++num_ranges; } + msg->tt_id.num_keys = -num_ranges; // mark as keyrange + + /* pack the metadata */ + std::memcpy(msg->bytes + pos, &metadata, metadata_size); + pos += metadata_size; + /* pack the local rank */ + int rank = world.rank(); + std::memcpy(msg->bytes + pos, &rank, sizeof(rank)); + pos += sizeof(rank); + /* pack the number of iovecs */ + std::memcpy(msg->bytes + pos, &num_iovs, sizeof(num_iovs)); + pos += sizeof(num_iovs); + + /* TODO: at the moment, the tag argument to parsec_ce.get() is treated as a + * raw function pointer instead of a preregistered AM tag, so play that game. + * Once this is fixed in PaRSEC we need to use parsec_ttg_rma_tag instead! */ + parsec_ce_tag_t cbtag = reinterpret_cast(&detail::get_remote_complete_cb); + std::memcpy(msg->bytes + pos, &cbtag, sizeof(cbtag)); + pos += sizeof(cbtag); + splitmd_pack_reghandles(memregs, msg.get(), pos); tp->tdm.module->outgoing_message_start(tp, owner, NULL); tp->tdm.module->outgoing_message_pack(tp, owner, NULL, NULL, 0); parsec_ce.send_am(&parsec_ce, world_impl.parsec_ttg_tag(), owner, static_cast(msg.get()), sizeof(msg_header_t) + pos); } /* handle local keys */ - broadcast_arg_local(local_begin, local_end, value); + if (rankset[world.rank()]) { + for (const auto& range : ranges) { + broadcast_arg_local(range.begin(), range.end(), value); + } + } } else { - /* handle local keys */ - broadcast_arg_local(keylist.begin(), keylist.end(), value); + for (const auto& range : ranges) { + broadcast_arg_local(range.begin(), range.end(), value); + } } + } // Used by invoke to set all arguments associated with a task @@ -2091,9 +2395,17 @@ namespace ttg_parsec { broadcast_arg(keylist, value); } }; + auto rangecast_callback = [this](const ttg::span> &keylist, + const valueT &value) { + if constexpr (ttg::has_split_metadata>::value) { + splitmd_rangecast_arg(keylist, value); + } else { + rangecast_arg(keylist, value); + } + }; auto setsize_callback = [this](const keyT &key, std::size_t size) { set_argstream_size(key, size); }; auto finalize_callback = [this](const keyT &key) { finalize_argstream(key); }; - input.set_callback(send_callback, move_callback, broadcast_callback, setsize_callback, finalize_callback); + input.set_callback(send_callback, move_callback, broadcast_callback, setsize_callback, finalize_callback, rangecast_callback); } ////////////////////////////////////////////////////////////////// // case 2: nonvoid key, void value, mixed inputs diff --git a/ttg/ttg/terminal.h b/ttg/ttg/terminal.h index 7cea6003e6..a4329b7f35 100644 --- a/ttg/ttg/terminal.h +++ b/ttg/ttg/terminal.h @@ -7,6 +7,7 @@ #include "ttg/base/terminal.h" #include "ttg/fwd.h" +#include "ttg/keyrange.h" #include "ttg/util/demangle.h" #include "ttg/util/meta.h" #include "ttg/util/trace.h" @@ -85,6 +86,7 @@ namespace ttg { using send_callback_type = meta::detail::send_callback_t>; using move_callback_type = meta::detail::move_callback_t>; using broadcast_callback_type = meta::detail::broadcast_callback_t>; + using rangecast_callback_type = meta::detail::rangecast_callback_t>; using setsize_callback_type = typename base_type::setsize_callback_type; using finalize_callback_type = typename base_type::finalize_callback_type; static constexpr bool is_an_input_terminal = true; @@ -93,6 +95,7 @@ namespace ttg { send_callback_type send_callback; move_callback_type move_callback; broadcast_callback_type broadcast_callback; + rangecast_callback_type rangecast_callback; // No moving, copying, assigning permitted In(In &&other) = delete; @@ -110,10 +113,12 @@ namespace ttg { void set_callback(const send_callback_type &send_callback, const move_callback_type &move_callback, const broadcast_callback_type &bcast_callback = broadcast_callback_type{}, const setsize_callback_type &setsize_callback = setsize_callback_type{}, - const finalize_callback_type &finalize_callback = finalize_callback_type{}) { + const finalize_callback_type &finalize_callback = finalize_callback_type{}, + const rangecast_callback_type &rangecast_callback = rangecast_callback_type{}) { this->send_callback = send_callback; this->move_callback = move_callback; this->broadcast_callback = bcast_callback; + this->rangecast_callback = rangecast_callback; base_type::set_callback(setsize_callback, finalize_callback); } @@ -157,7 +162,9 @@ namespace ttg { // An optimized implementation will need a separate callback for broadcast // with a specific value for rangeT template - std::enable_if_t, void> broadcast(const rangeT &keylist, const Value &value) { + std::enable_if_t && + !meta::is_iterable_of_v>, void> + broadcast(const rangeT &keylist, const Value &value) { if (broadcast_callback) { if constexpr (ttg::meta::is_iterable_v) { broadcast_callback(ttg::span(&(*std::begin(keylist)), std::distance(std::begin(keylist), std::end(keylist))), @@ -176,29 +183,10 @@ namespace ttg { } } - template - std::enable_if_t, void> broadcast(const rangeT &keylist, Value &&value) { - const Value &v = value; - if (broadcast_callback) { - if constexpr (ttg::meta::is_iterable_v) { - broadcast_callback( - ttg::span(&(*std::begin(keylist)), std::distance(std::begin(keylist), std::end(keylist))), v); - } else { - /* got something we cannot iterate over (single element?) so put one element in the span */ - broadcast_callback(ttg::span(&keylist, 1), v); - } - } else { - if constexpr (ttg::meta::is_iterable_v) { - for (auto &&key : keylist) send(key, v); - } else { - /* got something we cannot iterate over (single element?) so put one element in the span */ - broadcast_callback(ttg::span(&keylist, 1), v); - } - } - } - template - std::enable_if_t, void> broadcast(const rangeT &keylist) { + std::enable_if_t && + !meta::is_iterable_of_v>, void> + broadcast(const rangeT &keylist) { if (broadcast_callback) { if constexpr (ttg::meta::is_iterable_v) { broadcast_callback( @@ -216,6 +204,55 @@ namespace ttg { } } } + + /** + * Overload for key ranges + */ + template + std::enable_if_t, void> broadcast(const ttg::detail::LinearKeyRange &range, const Value &value) { + if (rangecast_callback) { + rangecast_callback(ttg::span(range, 1), value); + } else { + for (const auto &key : range) send(key, value); + } + } + + template + std::enable_if_t, void> broadcast(const ttg::detail::LinearKeyRange &range) { + if (rangecast_callback) { + rangecast_callback(ttg::span(range, 1)); + } else { + for (const auto &key : range) sendk(key); + } + } + + template + std::enable_if_t && + meta::is_iterable_of_v>, void> + broadcast(const rangeT &rangelist, const Value &value) { + if (rangecast_callback) { + rangecast_callback(ttg::span(&(*std::begin(rangelist)), std::distance(std::begin(rangelist), std::end(rangelist))), + value); + } else { + for (const auto& range : rangelist) + for (const auto &key : range) + send(key, value); + } + } + + template + std::enable_if_t && + meta::is_iterable_of_v>, void> + broadcast(const rangeT &rangelist) { + if (rangecast_callback) { + rangecast_callback(ttg::span(&(*std::begin(rangelist)), std::distance(std::begin(rangelist), std::end(rangelist)))); + } else { + for (const auto& range : rangelist) + for (const auto &key : range) + sendk(key); + } + } + }; template diff --git a/ttg/ttg/util/bits.h b/ttg/ttg/util/bits.h new file mode 100644 index 0000000000..dee2c2950f --- /dev/null +++ b/ttg/ttg/util/bits.h @@ -0,0 +1,84 @@ +#ifndef TTG_BITS_H +#define TTG_BITS_H + +#ifdef __cpp_lib_bitops +#include +#endif // __cpp_lib_bitops + +#include + +/** + * Implement some functions of the header introduced in C++20. + * + * Also provides a dynamically allocated bitset. + */ + +namespace ttg { + +#ifdef __cpp_lib_bitops + template< class T > + constexpr int popcount( T x ) noexcept { + return std::popcount(x); + } +#else + template< class T > + constexpr int popcount( T x ) noexcept { + int res = 0; + for (int i = 0; i < sizeof(T)*8; ++i) { + res += !!(x & (1< m_storage; + mutable ssize_t m_popcnt = -1; + constexpr static size_t storage_size = sizeof(storage_type); + + public: + bitset(size_t size) : m_storage((size+storage_size-1)/storage_size) + { } + + void set(size_t i) noexcept { + m_storage[i/storage_size] |= 1<<(i%storage_size); + m_popcnt = -1; + } + + bool get(size_t i) const noexcept { + return !!(m_storage[i/storage_size] & (1<<(i%storage_size))); + } + + bool operator[](size_t i) const noexcept { + return get(i); + } + + size_t size() const noexcept { + return m_storage.size()*storage_size; + } + + size_t popcnt() const noexcept { + if (m_popcnt == -1) { + m_popcnt = 0; + for (const auto& v : m_storage) { + m_popcnt += ttg::popcount(v); + } + } + return m_popcnt; + } + + void clear() noexcept { + size_t size = m_storage.size(); + m_storage.clear(); + m_storage.resize(size); + } + }; + + +} // namespace ttg + +#endif // TTG_BITS_H diff --git a/ttg/ttg/util/meta.h b/ttg/ttg/util/meta.h index 179bf2dbc5..79c2cfcca9 100644 --- a/ttg/ttg/util/meta.h +++ b/ttg/ttg/util/meta.h @@ -11,6 +11,11 @@ namespace ttg { class Void; + namespace detail { + template + struct LinearKeyRange; + } // namespace detail + namespace meta { #if __cplusplus >= 201703L @@ -650,6 +655,31 @@ namespace ttg { template using broadcast_callback_t = typename broadcast_callback::type; + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // rangecast_callback_t = std::function, protected against void key or value + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + struct rangecast_callback; + template + struct rangecast_callback && !is_void_v>> { + using type = std::function> &, const Value &)>; + }; + template + struct rangecast_callback && is_void_v>> { + using type = std::function> &)>; + }; + template + struct rangecast_callback && !is_void_v>> { + using type = std::function; + }; + template + struct rangecast_callback && is_void_v>> { + using type = std::function; + }; + template + using rangecast_callback_t = typename rangecast_callback::type; + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // setsize_callback_t = std::function protected against void key //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -742,6 +772,21 @@ namespace ttg { template constexpr bool is_iterable_v = is_iterable::value; + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // check whether a type is iterable yielding a specified value type + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + struct is_iterable_of : std::false_type {}; + + // this gets used only when we can call std::begin() and std::end() on that type + template + struct is_iterable_of()))>> + && std::is_same_v()))>>>> + : std::true_type {}; + + template + constexpr bool is_iterable_of_v = is_iterable_of::value; + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// // check whether a Callable is invocable with the arguments given as a typelist //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -755,6 +800,89 @@ namespace ttg { constexpr bool is_invocable_typelist_r_v> = std::is_invocable_r_v; + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // check whether a type has operator++ defined + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + struct has_pre_increment : std::false_type + { }; + + template + struct has_pre_increment()))>> + : std::true_type + { }; + + template + constexpr bool has_pre_increment_v = has_pre_increment::value; + + template + struct has_post_increment : std::false_type + { }; + + template + struct has_post_increment())++)>> + : std::true_type + { }; + + template + constexpr bool has_post_increment_v = has_post_increment::value; + + template + constexpr bool has_increment_v = has_pre_increment::value || has_post_increment::value; + + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // check whether a type has operator- and operator+ defined + //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + template + struct has_difference : std::false_type + { }; + + template + struct has_difference()) - (std::declval()))>> + : std::true_type + { }; + + template + constexpr bool has_difference_v = has_difference::value; + + template + struct has_addition : std::false_type + { }; + + template + struct has_addition()) + (std::declval()))>> + : std::true_type + { }; + + template + constexpr bool has_addition_v = has_addition::value; + + template + struct has_compound_addition : std::false_type + { }; + + template + struct has_compound_addition()) += (std::declval()))>> + : std::true_type + { }; + + template + constexpr bool has_compound_addition_v = has_compound_addition::value; + + /* Check if a type is comparable with itself. + * TODO: use of C++20 is_comparable once it's available */ + template + struct is_comparable : std::false_type + { }; + + template + struct is_comparable()) == (std::declval()))>> + : std::true_type + { }; + + template + constexpr bool is_comparable_v = is_comparable::value; + } // namespace meta } // namespace ttg