
Commit

Merge remote-tracking branch 'origin/master' into feat/snakes-on-a-byte-transport
thorstenhater committed May 22, 2024
2 parents d45b726 + 689eea3 commit 90622cb
Showing 54 changed files with 6,259 additions and 90 deletions.
2 changes: 2 additions & 0 deletions arbor/CMakeLists.txt
@@ -44,6 +44,8 @@ set(arbor_sources
morph/segment_tree.cpp
morph/stitch.cpp
merge_events.cpp
network.cpp
network_impl.cpp
simulation.cpp
partition_load_balance.cpp
profile/clock.cpp
53 changes: 35 additions & 18 deletions arbor/communication/communicator.cpp
@@ -14,6 +14,7 @@
#include "connection.hpp"
#include "distributed_context.hpp"
#include "execution_context.hpp"
#include "network_impl.hpp"
#include "profile/profiler_macro.hpp"
#include "threading/threading.hpp"
#include "util/partition.hpp"
@@ -24,14 +25,12 @@

namespace arb {

communicator::communicator(const recipe& rec,
const domain_decomposition& dom_dec,
execution_context& ctx): num_total_cells_{rec.num_cells()},
num_local_cells_{dom_dec.num_local_cells()},
num_local_groups_{dom_dec.num_groups()},
num_domains_{(cell_size_type) ctx.distributed->size()},
distributed_{ctx.distributed},
thread_pool_{ctx.thread_pool} {}
communicator::communicator(const recipe& rec, const domain_decomposition& dom_dec, context ctx):
num_total_cells_{rec.num_cells()},
num_local_cells_{dom_dec.num_local_cells()},
num_local_groups_{dom_dec.num_groups()},
num_domains_{(cell_size_type)ctx->distributed->size()},
ctx_(std::move(ctx)) {}

constexpr inline
bool is_external(cell_gid_type c) {
@@ -55,7 +54,7 @@ cell_member_type global_cell_of(const cell_member_type& c) {
return {c.gid | msb, c.index};
}

void communicator::update_connections(const connectivity& rec,
void communicator::update_connections(const recipe& rec,
const domain_decomposition& dom_dec,
const label_resolution_map& source_resolution_map,
const label_resolution_map& target_resolution_map) {
@@ -67,6 +66,9 @@ void communicator::update_connections(const connectivity& rec,
index_divisions_.clear();
PL();

// Construct connections from high-level specification
auto generated_connections = generate_connections(rec, ctx_, dom_dec);

// Make a list of local cells' connections
// -> gid_connections
// Count the number of local connections (i.e. connections terminating on this domain)
@@ -114,9 +116,18 @@
}
part_ext_connections.push_back(gid_ext_connections.size());
}
for (const auto& c: generated_connections) {
auto sgid = c.source.gid;
if (sgid >= num_total_cells_) {
throw arb::bad_connection_source_gid(c.source.gid, sgid, num_total_cells_);
}
const auto src = dom_dec.gid_domain(sgid);
src_domains.push_back(src);
src_counts[src]++;
}

util::make_partition(connection_part_, src_counts);
auto n_cons = gid_connections.size();
auto n_cons = gid_connections.size() + generated_connections.size();
auto n_ext_cons = gid_ext_connections.size();
PL();

@@ -132,6 +143,7 @@
auto target_resolver = resolver(&target_resolution_map);
for (const auto index: util::make_span(num_local_cells_)) {
const auto tgt_gid = gids[index];
const auto iod = dom_dec.index_on_domain(tgt_gid);
auto source_resolver = resolver(&source_resolution_map);
for (const auto cidx: util::make_span(part_connections[index], part_connections[index+1])) {
const auto& conn = gid_connections[cidx];
@@ -141,18 +153,23 @@
auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
auto offset = offsets[*src_domain]++;
++src_domain;
connections[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, index};
connections[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, iod};
}
for (const auto cidx: util::make_span(part_ext_connections[index], part_ext_connections[index+1])) {
const auto& conn = gid_ext_connections[cidx];
auto src = global_cell_of(conn.source);
auto src_gid = conn.source.rid;
if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(tgt_gid, src_gid);
auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
ext_connections[ext] = {src, tgt_lid, conn.weight, conn.delay, index};
ext_connections[ext] = {src, tgt_lid, conn.weight, conn.delay, iod};
++ext;
}
}
for (const auto& c: generated_connections) {
auto offset = offsets[*src_domain]++;
++src_domain;
connections[offset] = c;
}
PL();

PE(init:communicator:update:index);
@@ -167,7 +184,7 @@
// Sort the connections for each domain.
// This is num_domains_ independent sorts, so it can be parallelized trivially.
const auto& cp = connection_part_;
threading::parallel_for::apply(0, num_domains_, thread_pool_.get(),
threading::parallel_for::apply(0, num_domains_, ctx_->thread_pool.get(),
[&](cell_size_type i) {
util::sort(util::subrange_view(connections, cp[i], cp[i+1]));
});
@@ -193,7 +210,7 @@ time_type communicator::min_delay() {
res = std::accumulate(ext_connections_.delays.begin(), ext_connections_.delays.end(),
res,
[](auto&& acc, time_type del) { return std::min(acc, del); });
res = distributed_->min(res);
res = ctx_->distributed->min(res);
return res;
}

@@ -206,7 +223,7 @@ communicator::exchange(std::vector<spike> local_spikes) {

PE(communication:exchange:gather);
// global all-to-all to gather a local copy of the global spike list on each node.
auto global_spikes = distributed_->gather_spikes(local_spikes);
auto global_spikes = ctx_->distributed->gather_spikes(local_spikes);
num_spikes_ += global_spikes.size();
PL();

@@ -217,7 +234,7 @@
local_spikes.end(),
[this] (const auto& s) { return !remote_spike_filter_(s); }));
}
auto remote_spikes = distributed_->remote_gather_spikes(local_spikes);
auto remote_spikes = ctx_->distributed->remote_gather_spikes(local_spikes);
PL();

PE(communication:exchange:gather:remote:post_process);
Expand All @@ -231,8 +248,8 @@ communicator::exchange(std::vector<spike> local_spikes) {
}

void communicator::set_remote_spike_filter(const spike_predicate& p) { remote_spike_filter_ = p; }
void communicator::remote_ctrl_send_continue(const epoch& e) { distributed_->remote_ctrl_send_continue(e); }
void communicator::remote_ctrl_send_done() { distributed_->remote_ctrl_send_done(); }
void communicator::remote_ctrl_send_continue(const epoch& e) { ctx_->distributed->remote_ctrl_send_continue(e); }
void communicator::remote_ctrl_send_done() { ctx_->distributed->remote_ctrl_send_done(); }

// Given
// * a set of connections and an index into the set
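The communicator.cpp changes above replace the bare execution_context& with Arbor's reference-counted context handle, stored as ctx_ and used to reach the distributed context and thread pool, while update_connections additionally folds in connections produced by generate_connections from the high-level network specification. A minimal sketch of the new call pattern, assuming Arbor's public make_context and partition_load_balance; my_recipe is a placeholder and not part of this commit:

// Illustrative sketch only; `my_recipe` stands for any arb::recipe implementation.
arb::context ctx = arb::make_context();                     // reference-counted context handle
auto decomp = arb::partition_load_balance(my_recipe, ctx);  // domain decomposition for this rank
arb::communicator comm(my_recipe, decomp, ctx);             // ctx is copied into the communicator's ctx_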
13 changes: 6 additions & 7 deletions arbor/communication/communicator.hpp
@@ -1,11 +1,11 @@
#pragma once

#include <vector>
#include <unordered_set>

#include <arbor/export.hpp>
#include <arbor/common_types.hpp>
#include <arbor/context.hpp>
#include <arbor/domain_decomposition.hpp>
#include <arbor/export.hpp>
#include <arbor/recipe.hpp>
#include <arbor/spike.hpp>

@@ -40,7 +40,7 @@ class ARB_ARBOR_API communicator {

explicit communicator(const recipe& rec,
const domain_decomposition& dom_dec,
execution_context& ctx);
context ctx);

/// The range of event queues that belong to cells in group i.
std::pair<cell_size_type, cell_size_type> group_queue_range(cell_size_type i);
@@ -78,7 +78,7 @@ class ARB_ARBOR_API communicator {
void remote_ctrl_send_continue(const epoch&);
void remote_ctrl_send_done();

void update_connections(const connectivity& rec,
void update_connections(const recipe& rec,
const domain_decomposition& dom_dec,
const label_resolution_map& source_resolution_map,
const label_resolution_map& target_resolution_map);
@@ -98,7 +98,7 @@
for (const auto& con: cons) {
idx_on_domain.push_back(con.index_on_domain);
srcs.push_back(con.source);
dests.push_back(con.destination);
dests.push_back(con.target);
weights.push_back(con.weight);
delays.push_back(con.delay);
}
@@ -136,10 +136,9 @@ class ARB_ARBOR_API communicator {
// Currently we have no partitions/indices/acceleration structures
connection_list ext_connections_;

distributed_context_handle distributed_;
task_system_handle thread_pool_;
std::uint64_t num_spikes_ = 0u;
std::uint64_t num_local_events_ = 0u;
context ctx_;
};

} // namespace arb
185 changes: 185 additions & 0 deletions arbor/communication/distributed_for_each.hpp
@@ -0,0 +1,185 @@
#pragma once

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <type_traits>
#include <utility>

#include "distributed_context.hpp"
#include "util/range.hpp"

namespace arb {

namespace impl {
template <class FUNC, typename... T, std::size_t... Is>
void for_each_in_tuple(FUNC&& func, std::tuple<T...>& t, std::index_sequence<Is...>) {
(func(Is, std::get<Is>(t)), ...);
}

template <class FUNC, typename... T>
void for_each_in_tuple(FUNC&& func, std::tuple<T...>& t) {
for_each_in_tuple(func, t, std::index_sequence_for<T...>());
}

template <class FUNC, typename... T1, typename... T2, std::size_t... Is>
void for_each_in_tuple_pair(FUNC&& func,
std::tuple<T1...>& t1,
std::tuple<T2...>& t2,
std::index_sequence<Is...>) {
(func(Is, std::get<Is>(t1), std::get<Is>(t2)), ...);
}

template <class FUNC, typename... T1, typename... T2>
void for_each_in_tuple_pair(FUNC&& func, std::tuple<T1...>& t1, std::tuple<T2...>& t2) {
for_each_in_tuple_pair(func, t1, t2, std::index_sequence_for<T1...>());
}

} // namespace impl


/*
* Collective operation, calling func on args supplied by each rank exactly once. The order of calls
* is unspecified. Requires
*
* - Item = util::range<ARGS>::value_type to be identical across all ranks
* - Item is trivially_copyable
* - Alignment of Item must not exceed std::max_align_t
* - func to be a callable type with signature
* void func(util::range<Item*>...)
* - func must not modify contents of range
* - All ranks in distributed must call this collectively.
*/
template <typename FUNC, typename... ARGS>
void distributed_for_each(FUNC&& func,
const distributed_context& distributed,
const util::range<ARGS>&... args) {

static_assert(sizeof...(args) > 0);
auto arg_tuple = std::forward_as_tuple(args...);

struct vec_info {
std::size_t offset; // offset in bytes
std::size_t size; // size in bytes
};

std::array<vec_info, sizeof...(args)> info;
std::size_t buffer_size = 0;

// Compute offsets in bytes for each vector when placed in common buffer
{
std::size_t offset = info.size() * sizeof(vec_info);
impl::for_each_in_tuple(
[&](std::size_t i, auto&& vec) {
using T = typename std::remove_reference_t<decltype(vec)>::value_type;
static_assert(std::is_trivially_copyable_v<T>);
static_assert(alignof(std::max_align_t) >= alignof(T));
static_assert(alignof(std::max_align_t) % alignof(T) == 0);

// make sure alignment of offset fulfills requirement
const auto alignment_excess = offset % alignof(T);
offset += alignment_excess > 0 ? alignof(T) - (alignment_excess) : 0;

const auto size_in_bytes = vec.size() * sizeof(T);

info[i].size = size_in_bytes;
info[i].offset = offset;

buffer_size = offset + size_in_bytes;
offset += size_in_bytes;
},
arg_tuple);
}

// compute maximum buffer size between ranks, such that we only allocate once
const std::size_t max_buffer_size = distributed.max(buffer_size);

std::tuple<util::range<typename std::remove_reference_t<decltype(args)>::value_type*>...>
ranges;

if (max_buffer_size == info.size() * sizeof(vec_info)) {
// if all empty, call function with empty ranges for each step and exit
impl::for_each_in_tuple_pair(
[&](std::size_t i, auto&& vec, auto&& r) {
using T = typename std::remove_reference_t<decltype(vec)>::value_type;
r = util::range<T*>(nullptr, nullptr);
},
arg_tuple,
ranges);

for (int step = 0; step < distributed.size(); ++step) { std::apply(func, ranges); }
return;
}

// use malloc for std::max_align_t alignment
auto deleter = [](char* ptr) { std::free(ptr); };
std::unique_ptr<char[], void (*)(char*)> buffer((char*)std::malloc(max_buffer_size), deleter);
std::unique_ptr<char[], void (*)(char*)> recv_buffer(
(char*)std::malloc(max_buffer_size), deleter);

// copy offset and size info to front of buffer
std::memcpy(buffer.get(), info.data(), info.size() * sizeof(vec_info));

// copy each vector to each location in buffer
impl::for_each_in_tuple(
[&](std::size_t i, auto&& vec) {
using T = typename std::remove_reference_t<decltype(vec)>::value_type;
std::copy(vec.begin(), vec.end(), (T*)(buffer.get() + info[i].offset));
},
arg_tuple);


const auto my_rank = distributed.id();
const auto left_rank = my_rank == 0 ? distributed.size() - 1 : my_rank - 1;
const auto right_rank = my_rank == distributed.size() - 1 ? 0 : my_rank + 1;

// exchange buffer in ring pattern and apply function at each step
for (int step = 0; step < distributed.size() - 1; ++step) {
// always expect to receive the max size but send actual size. MPI_Recv only expects a max
// size, not the actual size.
const auto current_info = (const vec_info*)buffer.get();

auto request = distributed.send_recv_nonblocking(max_buffer_size,
recv_buffer.get(),
right_rank,
current_info[info.size() - 1].offset + current_info[info.size() - 1].size,
buffer.get(),
left_rank,
0);

// update ranges
impl::for_each_in_tuple_pair(
[&](std::size_t i, auto&& vec, auto&& r) {
using T = typename std::remove_reference_t<decltype(vec)>::value_type;
r = util::range<T*>((T*)(buffer.get() + current_info[i].offset),
(T*)(buffer.get() + current_info[i].offset + current_info[i].size));
},
arg_tuple,
ranges);

// call provided function with ranges pointing to current buffer
std::apply(func, ranges);

request.finalize();
buffer.swap(recv_buffer);
}

// final step does not require any exchange
const auto current_info = (const vec_info*)buffer.get();
impl::for_each_in_tuple_pair(
[&](std::size_t i, auto&& vec, auto&& r) {
using T = typename std::remove_reference_t<decltype(vec)>::value_type;
r = util::range<T*>((T*)(buffer.get() + current_info[i].offset),
(T*)(buffer.get() + current_info[i].offset + current_info[i].size));
},
arg_tuple,
ranges);

// call provided function with ranges pointing to current buffer
std::apply(func, ranges);
}

} // namespace arb
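
As a usage illustration of the new distributed_for_each header (not part of this commit), the sketch below wraps each rank's local gid vector in a util::range and visits every rank's contribution exactly once, per the contract in the header comment above. The function name, the include path, and the way a distributed_context reference is obtained are assumptions; only distributed_for_each and util::range come from the diff.

#include <vector>

#include "communication/distributed_for_each.hpp"  // assumes compilation inside the Arbor source tree

// Visit every rank's gid list once; the order across ranks is unspecified.
void visit_all_gids(const arb::distributed_context& dist, std::vector<int>& gids) {
    // Wrap the local, contiguous data in a range of pointers (Item = int is trivially copyable).
    auto local = arb::util::range<int*>(gids.data(), gids.data() + gids.size());
    arb::distributed_for_each(
        [](const arb::util::range<int*>& remote) {
            for (auto gid: remote) { (void)gid; /* inspect one rank's gids here */ }
        },
        dist,
        local);
}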
