diff --git a/arbor/communication/communicator.cpp b/arbor/communication/communicator.cpp
index bb206f7989..8a4e7be0a6 100644
--- a/arbor/communication/communicator.cpp
+++ b/arbor/communication/communicator.cpp
@@ -55,125 +55,218 @@ cell_member_type global_cell_of(const cell_member_type& c) {
     return {c.gid | msb, c.index};
 }
 
-void communicator::update_connections(const connectivity& rec,
-                                      const domain_decomposition& dom_dec,
-                                      const label_resolution_map& source_resolution_map,
-                                      const label_resolution_map& target_resolution_map) {
-    PE(init:communicator:update:clear);
-    // Forget all lingering information
-    connections_.clear();
-    connection_part_.clear();
-    index_divisions_.clear();
-    PL();
-
-    // Make a list of local cells' connections
-    //   -> gid_connections
-    // Count the number of local connections (i.e. connections terminating on this domain)
-    //   -> n_cons: scalar
-    // Calculate and store domain id of the presynaptic cell on each local connection
-    //   -> src_domains: array with one entry for every local connection
-    // Also the count of presynaptic sources from each domain
-    //   -> src_counts: array with one entry for each domain
-
-    // Record all the gid in a flat vector.
-
-    PE(init:communicator:update:collect_gids);
-    std::vector<cell_gid_type> gids; gids.reserve(num_local_cells_);
-    for (const auto& g: dom_dec.groups()) util::append(gids, g.gids);
-    PL();
-
-    // Build the connection information for local cells.
-    PE(init:communicator:update:gid_connections);
+// Build the local (i.e. Arbor-to-Arbor) connection list.
+// Writes
+// * connections := [connection]
+// * connection_part := [index into connections]
+//   - such that all connections _from the nth source domain_ are located
+//     between connection_part[n] and connection_part[n+1] in connections.
+//   - source domains are the MPI ranks associated with the gid of the source
+//     of a connection.
+//   - as the spike buffer is sorted and partitioned by said source domain, we
+//     can use this to quickly filter the spike buffer for spikes relevant to us.
+// * index_divisions_ and index_part_
+//   - index_part_ is used to map a cell group index to a range of queue indices.
+//   - these indices identify the queues in the simulation belonging to the cells
+//     of the nth cell group.
+//   - a queue stores the incoming events for one cell.
+//   - events are not identical to spikes, but constructed from them.
+//   - this indirection is needed as the communicator/simulation is responsible
+//     for multiple cell groups.
+//   - index_part_ is a view onto index_divisions_. The latter is not used directly,
+//     but is the backing data of the former.
+//     (Essentially index_part[n] = range(index_div[n], index_div[n+1]).)
+void update_local_connections(const connectivity& rec,
+                              const domain_decomposition& dec,
+                              const std::vector<cell_gid_type>& gids,
+                              size_t num_total_cells,
+                              size_t num_local_cells,
+                              size_t num_domains,
+                              // Outputs; written into communicator
+                              std::vector<connection>& connections_,
+                              std::vector<cell_size_type>& connection_part_,
+                              std::vector<cell_size_type>& index_divisions_,
+                              util::partition_view_type<std::vector<cell_size_type>>& index_part_,
+                              task_system_handle thread_pool_,
+                              // Mutable state for label resolution.
+                              resolver& target_resolver,
+                              resolver& source_resolver) {
+    PE(init:communicator:update:local:gid_connections);
+    // List all connections and partition them by their _target cell's index_.
     std::vector<cell_connection> gid_connections;
-    std::vector<ext_cell_connection> gid_ext_connections;
     std::vector<cell_size_type> part_connections;
-    part_connections.reserve(num_local_cells_);
+    part_connections.reserve(num_local_cells);
     part_connections.push_back(0);
-    std::vector<cell_size_type> part_ext_connections;
-    part_ext_connections.reserve(num_local_cells_);
-    part_ext_connections.push_back(0);
+    // Map connection _index_ to the id of the source gid's domain.
+    // e.g.:
+    //   Our gids    [23, 42], indices [0, 1], our domain is 3
+    //   Connections [ 42 <- 0, 42 <- 1, 23 <- 5, 42 <- 23]
+    //   Domains     [[0, 1, 2, 3], [4, 5], [...], [23, 42]]
+    // Thus we get
+    //   Src Domains [ 0, 0, 1, 3]
+    //   Src Counts  [ 2, 1, 0, 1]
+    //   Partition   [ 0, 2, 3, 3, 4]
     std::vector<unsigned> src_domains;
-    std::vector<cell_size_type> src_counts(num_domains_);
+    std::vector<cell_size_type> src_counts(num_domains);
+
+    // Build the data structures above.
     for (const auto gid: gids) {
-        // Local
         const auto& conns = rec.connections_on(gid);
         for (const auto& conn: conns) {
-            const auto sgid = conn.source.gid;
-            if (sgid >= num_total_cells_) throw arb::bad_connection_source_gid(gid, sgid, num_total_cells_);
-            const auto src = dom_dec.gid_domain(sgid);
+            const auto src_gid = conn.source.gid;
+            if (is_external(src_gid)) throw arb::source_gid_exceeds_limit(gid, src_gid);
+            if (src_gid >= num_total_cells) throw arb::bad_connection_source_gid(gid, src_gid, num_total_cells);
+            const auto src = dec.gid_domain(src_gid);
             src_domains.push_back(src);
             src_counts[src]++;
             gid_connections.emplace_back(conn);
         }
         part_connections.push_back(gid_connections.size());
-        // Remote
-        const auto& ext_conns = rec.external_connections_on(gid);
-        for (const auto& conn: ext_conns) {
-            gid_ext_connections.emplace_back(conn);
-        }
-        part_ext_connections.push_back(gid_ext_connections.size());
     }
+    // Construct the partitioning of connections by src_domains, thus
+    // continuing the above example:
+    //   connection_part_ [ 0, 2, 3, 3, 4]
+    // mapping the ranges
+    //   [0-2, 2-3, 3-3, 3-4]
+    // in the to-be-created connection array.
     util::make_partition(connection_part_, src_counts);
-    auto n_cons = gid_connections.size();
-    auto n_ext_cons = gid_ext_connections.size();
     PL();
 
     // Construct the connections. The loop above gave us the information needed
     // to do this in place.
-    // NOTE: The connections are partitioned by the domain of their source gid.
-    PE(init:communicator:update:connections);
-    connections_.resize(n_cons);
-    ext_connections_.resize(n_ext_cons);
-    auto offsets = connection_part_; // Copy, as we use this as the list of current target indices to write into
-    std::size_t ext = 0;
-    auto src_domain = src_domains.begin();
-    auto target_resolver = resolver(&target_resolution_map);
-    for (const auto index: util::make_span(num_local_cells_)) {
+    PE(init:communicator:update:local:connections);
+    connections_.resize(gid_connections.size());
+    // Copy connection_part_; we use the copy as the list of current target indices to write into.
+    struct offset_t {
+        std::vector<unsigned>::iterator source;
+        std::vector<cell_size_type> offsets;
+        cell_size_type next() { return offsets[*source++]++; }
+    };
+
+    auto offsets = offset_t{src_domains.begin(), connection_part_};
+    for (const auto index: util::make_span(num_local_cells)) {
         const auto tgt_gid = gids[index];
-        auto source_resolver = resolver(&source_resolution_map);
-        for (const auto cidx: util::make_span(part_connections[index], part_connections[index+1])) {
+        for (const auto cidx: util::make_span(part_connections[index],
+                                              part_connections[index+1])) {
             const auto& conn = gid_connections[cidx];
             auto src_gid = conn.source.gid;
-            if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(tgt_gid, src_gid);
             auto src_lid = source_resolver.resolve(conn.source);
             auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
-            auto offset = offsets[*src_domain]++;
-            ++src_domain;
-            connections_[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, index};
-        }
-        for (const auto cidx: util::make_span(part_ext_connections[index], part_ext_connections[index+1])) {
-            const auto& conn = gid_ext_connections[cidx];
-            auto src = global_cell_of(conn.source);
-            auto src_gid = conn.source.rid;
-            if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(tgt_gid, src_gid);
-            auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
-            ext_connections_[ext] = {src, tgt_lid, conn.weight, conn.delay, index};
-            ++ext;
+            auto out_idx = offsets.next();
+            connections_[out_idx] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, (cell_size_type)index};
         }
+        source_resolver.reset();
     }
     PL();
 
-    PE(init:communicator:update:index);
+    PE(init:communicator:update:local:index);
     // Build cell partition by group for passing events to cell groups
     index_part_ = util::make_partition(index_divisions_,
-                                       util::transform_view(
-                                           dom_dec.groups(),
-                                           [](const group_description& g){ return g.gids.size(); }));
+                                       util::transform_view(dec.groups(),
+                                                            [](const auto& g){ return g.gids.size(); }));
     PL();
 
-    PE(init:communicator:update:sort_connections);
+    PE(init:communicator:update:local:sort_connections);
     // Sort the connections for each domain.
     // This is num_domains_ independent sorts, so it can be parallelized trivially.
-    const auto& cp = connection_part_;
-    threading::parallel_for::apply(0, num_domains_, thread_pool_.get(),
+    threading::parallel_for::apply(0, num_domains, thread_pool_.get(),
                                    [&](cell_size_type i) {
-                                       util::sort(util::subrange_view(connections_, cp[i], cp[i+1]));
+                                       util::sort(util::subrange_view(connections_,
+                                                                      connection_part_[i],
+                                                                      connection_part_[i+1]));
                                    });
+    PL();
+}
+
+// Build lists for the _remote_ connections. No fancy acceleration structures
+// are built and the list is globally sorted.
+void update_remote_connections(const connectivity& rec,
+                               const domain_decomposition& dec,
+                               const std::vector<cell_gid_type>& gids,
+                               size_t num_total_cells,
+                               size_t num_local_cells,
+                               size_t num_domains,
+                               // Outputs; written into communicator
+                               std::vector<connection>& ext_connections_,
+                               // Mutable state for label resolution.
+                               resolver& target_resolver,
+                               resolver& source_resolver) {
+    PE(init:communicator:update:remote:gid_connections);
+    std::vector<ext_cell_connection> gid_ext_connections;
+    std::vector<cell_size_type> part_ext_connections;
+    part_ext_connections.reserve(num_local_cells);
+    part_ext_connections.push_back(0);
+    for (const auto gid: gids) {
+        const auto& ext_conns = rec.external_connections_on(gid);
+        for (const auto& conn: ext_conns) {
+            // NOTE: This might look like a bug, but the _remote id_ is considered to live
+            // in the remote id space, i.e. it must not already be tagged as remote.
+            if (is_external(conn.source.rid)) throw arb::source_gid_exceeds_limit(gid, conn.source.rid);
+            gid_ext_connections.emplace_back(conn);
+        }
+        part_ext_connections.push_back(gid_ext_connections.size());
+    }
+    PL();
+
+    // Construct the connections. The loop above gave us the information needed
+    // to do this in place.
+    PE(init:communicator:update:remote:connections);
+    ext_connections_.resize(gid_ext_connections.size());
+    std::size_t ext = 0;
+    for (const auto index: util::make_span(num_local_cells)) {
+        const auto tgt_gid = gids[index];
+        for (const auto cidx: util::make_span(part_ext_connections[index],
+                                              part_ext_connections[index+1])) {
+            const auto& conn = gid_ext_connections[cidx];
+            auto src = global_cell_of(conn.source);
+            auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
+            ext_connections_[ext] = {src, tgt_lid, conn.weight, conn.delay, (cell_size_type)index};
+            ++ext;
+        }
+        source_resolver.reset();
+    }
+    PL();
+
+    PE(init:communicator:update:remote:sort_connections);
     std::sort(ext_connections_.begin(), ext_connections_.end());
     PL();
 }
 
+void communicator::update_connections(const connectivity& rec,
+                                      const domain_decomposition& dom_dec,
+                                      const label_resolution_map& source_resolution_map,
+                                      const label_resolution_map& target_resolution_map) {
+    PE(init:communicator:update:clear);
+    // Forget all lingering information
+    connections_.clear();
+    connection_part_.clear();
+    ext_connections_.clear();
+    index_divisions_.clear();
+    PL();
+
+    // Remember the list of local gids.
+    PE(init:communicator:update:collect_gids);
+    std::vector<cell_gid_type> gids; gids.reserve(num_local_cells_);
+    for (const auto& g: dom_dec.groups()) util::append(gids, g.gids);
+    PL();
+
+    // Build resolvers.
+    auto target_resolver = resolver(&target_resolution_map);
+    auto source_resolver = resolver(&source_resolution_map);
+
+    update_local_connections(rec, dom_dec, gids,
+                             num_total_cells_, num_local_cells_, num_domains_,
+                             connections_,
+                             connection_part_, index_divisions_,
+                             index_part_,
+                             thread_pool_,
+                             target_resolver, source_resolver);
+
+    update_remote_connections(rec, dom_dec, gids,
+                              num_total_cells_, num_local_cells_, num_domains_,
+                              ext_connections_,
+                              target_resolver, source_resolver);
+}
+
 std::pair<cell_size_type, cell_size_type> communicator::group_queue_range(cell_size_type i) {
     arb_assert(i<num_local_groups_);
diff --git a/arbor/communication/communicator.hpp b/arbor/communication/communicator.hpp
-    void make_event_queues(const gathered_vector<spike>& global_spikes,
-                           std::vector<pse_vector>& queues,
-                           const std::vector<spike>& external_spikes={});
+    void make_event_queues(const gathered_vector<spike>& global_spikes,
+                           std::vector<pse_vector>& queues,
+                           const std::vector<spike>& external_spikes={});
 
     /// Returns the total number of global spikes over the duration of the simulation
     std::uint64_t num_spikes() const;
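The comment block in update_local_connections above describes the central trick: connections are bucketed by the MPI rank that owns their source gid, so a rank-partitioned spike buffer can be matched against a single contiguous slice per rank. The following is a minimal, self-contained sketch of that idea; the names (conn, make_partition) are illustrative stand-ins, the real code uses util::make_partition and util::subrange_view.

```cpp
// Sketch: partition connections by source domain and look up the slice relevant
// to one domain's spikes. Illustrative only, not Arbor's util:: API.
#include <cstddef>
#include <vector>

struct conn { unsigned src_gid; double weight; };

// Exclusive scan over per-domain counts: counts [2, 1, 0, 1] -> offsets [0, 2, 3, 3, 4].
std::vector<std::size_t> make_partition(const std::vector<std::size_t>& counts) {
    std::vector<std::size_t> part(counts.size() + 1, 0);
    for (std::size_t i = 0; i < counts.size(); ++i) part[i + 1] = part[i] + counts[i];
    return part;
}

int main() {
    // Connections already bucketed by source domain, as update_local_connections arranges.
    std::vector<conn> connections = {{0, 1.0}, {1, 0.5}, {5, 2.0}, {23, 0.1}};
    std::vector<std::size_t> src_counts = {2, 1, 0, 1};   // one entry per domain
    auto connection_part = make_partition(src_counts);    // [0, 2, 3, 3, 4]

    // Spikes arriving from domain d only need to be matched against this slice:
    unsigned d = 1;
    for (auto i = connection_part[d]; i < connection_part[d + 1]; ++i) {
        (void)connections[i];  // every connection here has a source owned by rank d
    }
}
```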
diff --git a/arbor/domain_decomposition.cpp b/arbor/domain_decomposition.cpp
index aa22082120..ad0ff854d9 100644
--- a/arbor/domain_decomposition.cpp
+++ b/arbor/domain_decomposition.cpp
@@ -7,110 +7,85 @@
 #include
 #include
+#include "cell_group_factory.hpp"
 #include "execution_context.hpp"
 #include "util/partition.hpp"
 #include "util/rangeutil.hpp"
 #include "util/span.hpp"
 
 namespace arb {
 
-domain_decomposition::domain_decomposition(
-    const recipe& rec,
-    context ctx,
-    const std::vector<group_description>& groups)
-{
-    struct partition_gid_domain {
-        partition_gid_domain(const gathered_vector<cell_gid_type>& divs, unsigned domains) {
-            auto rank_part = util::partition_view(divs.partition());
-            for (auto rank: count_along(rank_part)) {
-                for (auto gid: util::subrange_view(divs.values(), rank_part[rank])) {
-                    gid_map[gid] = rank;
-                }
-            }
-        }
-        int operator()(cell_gid_type gid) const {
-            return gid_map.at(gid);
-        }
-        std::unordered_map<cell_gid_type, int> gid_map;
-    };
+domain_decomposition::domain_decomposition(const recipe& rec,
+                                           context ctx,
+                                           std::vector<group_description> groups):
+    num_global_cells_{rec.num_cells()},
+    groups_(std::move(groups))
+{
     const auto* dist = ctx->distributed.get();
-    unsigned num_domains = dist->size();
-    int domain_id = dist->id();
-    cell_size_type num_global_cells = rec.num_cells();
-    const bool has_gpu = ctx->gpu->has_gpu();
+    num_domains_ = dist->size();
+    domain_id_ = dist->id();
+    // Collect the local gid set and do a first check on it:
+    //  * Are all GJ-connected cells in the same group?
     std::vector<cell_gid_type> local_gids;
-    for (const auto& g: groups) {
-        if (g.backend == backend_kind::gpu && !has_gpu) {
-            throw invalid_backend(domain_id);
-        }
-        if (g.backend == backend_kind::gpu && g.kind != cell_kind::cable) {
-            throw incompatible_backend(domain_id, g.kind);
-        }
-
+    for (const auto& g: groups_) {
+        // Check whether the backend is supported and bail out if not.
+        // TODO: This would benefit from generalisation.
+        if (!has_backend(ctx, g.backend)) throw invalid_backend(domain_id_, g.backend);
+        if (!cell_kind_supported(g.kind, g.backend, *ctx)) throw incompatible_backend(domain_id_, g.kind, g.backend);
+        // Check GJ cliques.
         std::unordered_set<cell_gid_type> gid_set(g.gids.begin(), g.gids.end());
         for (const auto& gid: g.gids) {
-            if (gid >= num_global_cells) {
-                throw out_of_bounds(gid, num_global_cells);
-            }
+            if (gid >= num_global_cells_) throw out_of_bounds(gid, num_global_cells_);
             for (const auto& gj: rec.gap_junctions_on(gid)) {
-                if (!gid_set.count(gj.peer.gid)) {
-                    throw invalid_gj_cell_group(gid, gj.peer.gid);
-                }
+                if (!gid_set.count(gj.peer.gid)) throw invalid_gj_cell_group(gid, gj.peer.gid);
             }
+            local_gids.push_back(gid);
         }
-        local_gids.insert(local_gids.end(), g.gids.begin(), g.gids.end());
     }
-    cell_size_type num_local_cells = local_gids.size();
+    num_local_cells_ = local_gids.size();
 
+    // MPI: Build the global gid list, including its partition into domains.
     auto global_gids = dist->gather_gids(local_gids);
-    if (global_gids.size() != num_global_cells) {
-        throw invalid_sum_local_cells(global_gids.size(), num_global_cells);
-    }
+    // Sanity check of the global gid list:
+    //  * missing GIDs?
+    //  * too many GIDs?
+    //  * duplicate GIDs?
+    //  * skipped GIDs?
     auto global_gid_vals = global_gids.values();
     util::sort(global_gid_vals);
     for (unsigned i = 1; i < global_gid_vals.size(); ++i) {
         if (global_gid_vals[i] == global_gid_vals[i-1]) {
             throw duplicate_gid(global_gid_vals[i]);
         }
+        if (global_gid_vals[i] > global_gid_vals[i-1] + 1) {
+            throw skipped_gid(global_gid_vals[i], global_gid_vals[i-1]);
+        }
     }
 
-    num_domains_ = num_domains;
-    domain_id_ = domain_id;
-    num_local_cells_ = num_local_cells;
-    num_global_cells_ = num_global_cells;
-    groups_ = groups;
-    gid_domain_ = partition_gid_domain(global_gids, num_domains);
+    // Build the map of gid -> domain id (aka MPI rank).
+    auto rank_part = util::partition_view(global_gids.partition());
+    for (auto rank: count_along(rank_part)) {
+        for (auto gid: util::subrange_view(global_gids.values(), rank_part[rank])) {
+            gid_map_[gid] = rank;
+        }
+    }
 }
 
-int domain_decomposition::gid_domain(cell_gid_type gid) const {
-    return gid_domain_(gid);
-}
+int domain_decomposition::gid_domain(cell_gid_type gid) const { return gid_map_.at(gid); }
 
-int domain_decomposition::num_domains() const {
-    return num_domains_;
-}
+int domain_decomposition::num_domains() const { return num_domains_; }
 
-int domain_decomposition::domain_id() const {
-    return domain_id_;
-}
+int domain_decomposition::domain_id() const { return domain_id_; }
 
-cell_size_type domain_decomposition::num_local_cells() const {
-    return num_local_cells_;
-}
+cell_size_type domain_decomposition::num_local_cells() const { return num_local_cells_; }
 
-cell_size_type domain_decomposition::num_global_cells() const {
-    return num_global_cells_;
-}
+cell_size_type domain_decomposition::num_global_cells() const { return num_global_cells_; }
 
-cell_size_type domain_decomposition::num_groups() const {
-    return groups_.size();
-}
+cell_size_type domain_decomposition::num_groups() const { return groups_.size(); }
 
-const std::vector<group_description>& domain_decomposition::groups() const {
-    return groups_;
-}
+const std::vector<group_description>& domain_decomposition::groups() const { return groups_; }
 
 const group_description& domain_decomposition::group(unsigned idx) const {
     arb_assert(idx<groups_.size());
diff --git a/arbor/execution_context.cpp b/arbor/execution_context.cpp
 ARB_ARBOR_API std::string distribution_type(context ctx) {
     return ctx->distributed->name();
 }
 
-ARB_ARBOR_API bool has_gpu(context ctx) {
-    return ctx->gpu->has_gpu();
-}
-
 ARB_ARBOR_API unsigned num_threads(context ctx) {
     return ctx->thread_pool->get_num_threads();
 }
 
@@ -92,5 +88,19 @@ ARB_ARBOR_API bool has_mpi(context ctx) {
     return ctx->distributed->name() == "MPI";
 }
 
+ARB_ARBOR_API bool has_gpu(context ctx) { return ctx->gpu->has_gpu(); }
+
+ARB_ARBOR_API bool has_backend(context ctx, backend_kind be) {
+    if (backend_kind::gpu == be) {
+        return has_gpu(ctx);
+    }
+    else if (backend_kind::multicore == be) {
+        return true;
+    } else {
+        // Impossible.
+        throw std::runtime_error{"Unknown backend"};
+    }
+}
+
 } // namespace arb
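The domain_decomposition constructor above gathers every rank's gids, checks the sorted list for duplicates and holes, and then builds the gid-to-rank map from the rank partition. Below is a self-contained sketch of those two steps; gid_t and build_gid_map are illustrative stand-ins for gathered_vector, util::partition_view and the duplicate_gid/skipped_gid exceptions.

```cpp
// Sketch of the gid sanity check and the gid -> rank map built in the constructor above.
// `values` is the concatenation of all ranks' gids, `partition` its per-rank offsets;
// both stand in for what dist->gather_gids() returns in Arbor.
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <unordered_map>
#include <vector>

using gid_t = unsigned;

std::unordered_map<gid_t, int> build_gid_map(std::vector<gid_t> values,
                                             const std::vector<std::size_t>& partition) {
    // 1. Sanity check: sorted gids must be exactly 0, 1, ..., N-1.
    auto sorted = values;
    std::sort(sorted.begin(), sorted.end());
    for (std::size_t i = 1; i < sorted.size(); ++i) {
        if (sorted[i] == sorted[i - 1])     throw std::runtime_error("duplicate gid");
        if (sorted[i] >  sorted[i - 1] + 1) throw std::runtime_error("skipped gid");
    }
    // 2. Map each gid to the rank whose slice of `values` contains it.
    std::unordered_map<gid_t, int> gid_map;
    for (std::size_t rank = 0; rank + 1 < partition.size(); ++rank) {
        for (auto i = partition[rank]; i < partition[rank + 1]; ++i) {
            gid_map[values[i]] = static_cast<int>(rank);
        }
    }
    return gid_map;
}
```

Compared to the removed partition_gid_domain functor, the map is now a plain member of domain_decomposition, which is what lets gid_domain() become a one-line lookup.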
 };
 
+inline ARB_ARBOR_API
+std::string backend_kind_str(backend_kind bk) {
+    switch (bk) {
+        case backend_kind::gpu: return "gpu";
+        case backend_kind::multicore: return "multicore";
+        default: throw std::runtime_error{"Unknown backend"};
+    }
+}
+
+inline ARB_ARBOR_API
+std::string cell_kind_str(cell_kind bk) {
+    switch (bk) {
+        case cell_kind::cable: return "cable";
+        case cell_kind::lif: return "lif";
+        case cell_kind::spike_source: return "spike_source";
+        case cell_kind::benchmark: return "benchmark";
+        default: throw std::runtime_error{"Unknown cell"};
+    }
+}
+
 ARB_ARBOR_API std::ostream& operator<<(std::ostream& o, lid_selection_policy m);
 ARB_ARBOR_API std::ostream& operator<<(std::ostream& o, cell_member_type m);
 ARB_ARBOR_API std::ostream& operator<<(std::ostream& o, cell_kind k);
diff --git a/arbor/include/arbor/context.hpp b/arbor/include/arbor/context.hpp
index 5d87aab5d7..92663c41c3 100644
--- a/arbor/include/arbor/context.hpp
+++ b/arbor/include/arbor/context.hpp
@@ -1,5 +1,6 @@
 #pragma once
+#include "arbor/common_types.hpp"
 #include
 #include
 
@@ -78,6 +79,7 @@ ARB_ARBOR_API context make_context(const proc_allocation& resources, Comm comm,
 
 // Queries for properties of execution resources in a context.
 ARB_ARBOR_API std::string distribution_type(context);
+ARB_ARBOR_API bool has_backend(context, backend_kind);
 ARB_ARBOR_API bool has_gpu(context);
 ARB_ARBOR_API unsigned num_threads(context);
 ARB_ARBOR_API bool has_mpi(context);
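The *_str helpers above and the enriched invalid_backend/incompatible_backend exceptions further down exist so that diagnostics can name both the cell kind and the backend involved. A rough, self-contained sketch of how such a message could be assembled follows; the enum definitions are local stand-ins, and the actual message text lives in domdecexcept.cpp, which is not shown in this diff, so the wording here is an assumption.

```cpp
// Sketch: an incompatible_backend-style exception formatting its message with the
// new *_str helpers. Local types only; not the arb:: definitions themselves.
#include <stdexcept>
#include <string>

enum class backend_kind { multicore, gpu };
enum class cell_kind { cable, lif, spike_source, benchmark };

std::string backend_kind_str(backend_kind bk) {
    return bk == backend_kind::gpu ? "gpu" : "multicore";
}

std::string cell_kind_str(cell_kind ck) {
    switch (ck) {
        case cell_kind::cable:        return "cable";
        case cell_kind::lif:          return "lif";
        case cell_kind::spike_source: return "spike_source";
        case cell_kind::benchmark:    return "benchmark";
    }
    return "unknown";
}

struct incompatible_backend: std::runtime_error {
    incompatible_backend(int rank, cell_kind k, backend_kind b):
        std::runtime_error("rank " + std::to_string(rank) + ": cell kind '" + cell_kind_str(k) +
                           "' is not supported on backend '" + backend_kind_str(b) + "'") {}
};
```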
diff --git a/arbor/include/arbor/domain_decomposition.hpp b/arbor/include/arbor/domain_decomposition.hpp
index 2706cd2935..96d2c0260c 100644
--- a/arbor/include/arbor/domain_decomposition.hpp
+++ b/arbor/include/arbor/domain_decomposition.hpp
@@ -39,7 +39,9 @@ struct group_description {
 class ARB_ARBOR_API domain_decomposition {
 public:
     domain_decomposition() = delete;
-    domain_decomposition(const recipe& rec, context ctx, const std::vector<group_description>& groups);
+    domain_decomposition(const recipe& rec,
+                         context ctx,
+                         std::vector<group_description> groups);
 
     domain_decomposition(const domain_decomposition&) = default;
     domain_decomposition& operator=(const domain_decomposition&) = default;
@@ -54,10 +56,8 @@ class ARB_ARBOR_API domain_decomposition {
     const group_description& group(unsigned) const;
 
 private:
-    /// Return the domain id of cell with gid.
-    /// Supplied by the load balancing algorithm that generates the domain
-    /// decomposition.
-    std::function<int(cell_gid_type)> gid_domain_;
+    /// Obtain the domain id of cell with gid.
+    std::unordered_map<cell_gid_type, int> gid_map_;
 
     /// Number of distributed domains
     int num_domains_;
diff --git a/arbor/include/arbor/domdecexcept.hpp b/arbor/include/arbor/domdecexcept.hpp
index aac8490733..22db4e6b1d 100644
--- a/arbor/include/arbor/domdecexcept.hpp
+++ b/arbor/include/arbor/domdecexcept.hpp
@@ -26,6 +26,13 @@ struct ARB_SYMBOL_VISIBLE duplicate_gid: dom_dec_exception {
     cell_gid_type gid;
 };
 
+struct ARB_SYMBOL_VISIBLE skipped_gid: dom_dec_exception {
+    skipped_gid(cell_gid_type gid, cell_gid_type nxt);
+    cell_gid_type gid;
+    cell_gid_type nxt;
+};
+
+
 struct ARB_SYMBOL_VISIBLE out_of_bounds: dom_dec_exception {
     out_of_bounds(cell_gid_type gid, unsigned num_cells);
     cell_gid_type gid;
@@ -33,14 +40,16 @@ struct ARB_SYMBOL_VISIBLE out_of_bounds: dom_dec_exception {
 };
 
 struct ARB_SYMBOL_VISIBLE invalid_backend: dom_dec_exception {
-    invalid_backend(int rank);
+    invalid_backend(int rank, backend_kind be);
     int rank;
+    backend_kind backend;
 };
 
 struct ARB_SYMBOL_VISIBLE incompatible_backend: dom_dec_exception {
-    incompatible_backend(int rank, cell_kind kind);
+    incompatible_backend(int rank, cell_kind kind, backend_kind back);
     int rank;
     cell_kind kind;
+    backend_kind backend;
 };
 
 } // namespace arb
diff --git a/arbor/label_resolution.cpp b/arbor/label_resolution.cpp
index dd928098b6..db1bf58550 100644
--- a/arbor/label_resolution.cpp
+++ b/arbor/label_resolution.cpp
@@ -227,5 +227,16 @@ cell_lid_type resolver::resolve(cell_gid_type gid, const cell_local_label_type&
     return *lid;
 }
 
+
+void resolver::reset() {
+    for (auto& [gid, tags]: state_map_) {
+        for (auto& [tag, states]: tags) {
+            states.clear();
+        }
+    }
+}
+
+void resolver::clear() { state_map_.clear(); }
+
 } // namespace arb
diff --git a/arbor/label_resolution.hpp b/arbor/label_resolution.hpp
index 9b30a41fa3..f20d684743 100644
--- a/arbor/label_resolution.hpp
+++ b/arbor/label_resolution.hpp
@@ -116,6 +116,9 @@ struct ARB_ARBOR_API resolver {
     using state_variant = std::variant;
 
+    void clear();
+    void reset();
+
 private:
     template
     using map = std::unordered_map;
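The new resolver::reset()/clear() split is what allows update_local_connections and update_remote_connections to reuse a single resolver across target cells: reset() rewinds the per-label state after each cell while keeping the cache structure, clear() drops everything. A stand-alone sketch of that state layout and the two operations follows; the map shapes and the label_state contents are assumptions based on this header excerpt, not the actual label_resolution types.

```cpp
// Sketch: two-level cache shaped like the loop in resolver::reset() above.
// reset() keeps the outer maps but clears per-tag state; clear() drops everything.
#include <string>
#include <unordered_map>
#include <vector>

struct label_state {
    std::vector<unsigned> lids;   // resolved local ids for one (gid, tag), hypothetical
    void clear() { lids.clear(); }
};

struct toy_resolver {
    // gid -> tag -> state, mirroring the shape implied by state_map_.
    std::unordered_map<unsigned, std::unordered_map<std::string, label_state>> state_map_;

    // Rewind per-tag state (e.g. round-robin positions) but keep the structure,
    // as done between target cells in update_local_connections.
    void reset() {
        for (auto& [gid, tags]: state_map_)
            for (auto& [tag, state]: tags)
                state.clear();
    }

    // Forget everything, as communicator::update_connections would want before a rebuild.
    void clear() { state_map_.clear(); }
};
```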