Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor domain_decomposition. #2210

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 172 additions & 79 deletions arbor/communication/communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,125 +55,218 @@ cell_member_type global_cell_of(const cell_member_type& c) {
return {c.gid | msb, c.index};
}

void communicator::update_connections(const connectivity& rec,
const domain_decomposition& dom_dec,
const label_resolution_map& source_resolution_map,
const label_resolution_map& target_resolution_map) {
PE(init:communicator:update:clear);
// Forget all lingering information
connections_.clear();
connection_part_.clear();
index_divisions_.clear();
PL();

// Make a list of local cells' connections
// -> gid_connections
// Count the number of local connections (i.e. connections terminating on this domain)
// -> n_cons: scalar
// Calculate and store domain id of the presynaptic cell on each local connection
// -> src_domains: array with one entry for every local connection
// Also the count of presynaptic sources from each domain
// -> src_counts: array with one entry for each domain

// Record all the gid in a flat vector.

PE(init:communicator:update:collect_gids);
std::vector<cell_gid_type> gids; gids.reserve(num_local_cells_);
for (const auto& g: dom_dec.groups()) util::append(gids, g.gids);
PL();

// Build the connection information for local cells.
PE(init:communicator:update:gid_connections);
// Build local(ie Arbor to Arbor) connection list
// Writes
// * connections := [connection]
// * connection_part := [index into connections]
// - such that all connections _from the nth source domain_ are located
// between connections_part[n] and connections_part[n+1] in connections.
// - source domains are the MPI ranks associated with the gid of the source
// of a connection.
// - as the spike buffer is sorted and partitioned by said source domain, we
// can use this to quickly filter spike buffer for spikes relevant to us.
// * index_divisions_ and index_part_
// - index_part_ is used to map a cell group index to a range of queue indices
// - these indices identify the queue in simulation belonging to cell the nth cell group
// - queue stores incoming events for a cell
// - events are not identical to spikes, but constructed from them
// - this indirection is needed as communicator/simulation is responsible for multiple
// cell groups.
// - index_part_ is a view onto index_divisions_. The latter is not directly used, but is
// the backing data of the former. (Essentially index_part[n] = range(index_div[n], index_div[n+1]))
void update_local_connections(const connectivity& rec,
const domain_decomposition& dec,
const std::vector<cell_gid_type>& gids,
size_t num_total_cells,
size_t num_local_cells,
size_t num_domains,
// Outputs; written into communicator
std::vector<connection>& connections_,
std::vector<cell_size_type>& connection_part_,
std::vector<cell_size_type>& index_divisions_,
util::partition_view_type<std::vector<cell_size_type>>& index_part_,
task_system_handle thread_pool_,
// Mutable state for label resolution.
resolver& target_resolver,
resolver& source_resolver) {
PE(init:communicator:update:local:gid_connections);
// List all connections and partition them by their _target cell's index_
std::vector<cell_connection> gid_connections;
std::vector<ext_cell_connection> gid_ext_connections;
std::vector<size_t> part_connections;
part_connections.reserve(num_local_cells_);
part_connections.reserve(num_local_cells);
part_connections.push_back(0);
std::vector<size_t> part_ext_connections;
part_ext_connections.reserve(num_local_cells_);
part_ext_connections.push_back(0);
// Map connection _index_ to the id of the source gid's domain.
// eg:
// Our gids [23, 42], indices [0, 1] and #domain 3
// Connections [ 42 <- 0, 42 <- 1, 23 <- 5, 42 <- 23, 23 <- 1]
// Domains [[0, 1, 2, 3], [4, 5], [...], [23, 42]]
// Thus we get
// Src Domains [ 0, 1, 3, 0]
// Src Counts [ 2, 1, 0, 1]
// Partitition [ 0 2 5 ]
std::vector<unsigned> src_domains;
std::vector<cell_size_type> src_counts(num_domains_);
std::vector<cell_size_type> src_counts(num_domains);

// Build the data structures above.
for (const auto gid: gids) {
// Local
const auto& conns = rec.connections_on(gid);
for (const auto& conn: conns) {
const auto sgid = conn.source.gid;
if (sgid >= num_total_cells_) throw arb::bad_connection_source_gid(gid, sgid, num_total_cells_);
const auto src = dom_dec.gid_domain(sgid);
const auto src_gid = conn.source.gid;
if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(gid, src_gid);
if (src_gid >= num_total_cells) throw arb::bad_connection_source_gid(gid, src_gid, num_total_cells);
const auto src = dec.gid_domain(src_gid);
src_domains.push_back(src);
src_counts[src]++;
gid_connections.emplace_back(conn);
}
part_connections.push_back(gid_connections.size());
// Remote
const auto& ext_conns = rec.external_connections_on(gid);
for (const auto& conn: ext_conns) {
gid_ext_connections.emplace_back(conn);
}
part_ext_connections.push_back(gid_ext_connections.size());
}

// Construct partitioning of connections on src_domains, thus
// continuing the above example:
// connection_part_ [ 0 2 3 3 4]
// mapping the ranges
// [0-2, 2-3, 3-3, 3-4]
// in the to-be-created connection array.
util::make_partition(connection_part_, src_counts);
auto n_cons = gid_connections.size();
auto n_ext_cons = gid_ext_connections.size();
PL();

// Construct the connections. The loop above gave us the information needed
// to do this in place.
// NOTE: The connections are partitioned by the domain of their source gid.
PE(init:communicator:update:connections);
connections_.resize(n_cons);
ext_connections_.resize(n_ext_cons);
auto offsets = connection_part_; // Copy, as we use this as the list of current target indices to write into
std::size_t ext = 0;
auto src_domain = src_domains.begin();
auto target_resolver = resolver(&target_resolution_map);
for (const auto index: util::make_span(num_local_cells_)) {
PE(init:communicator:update:local:connections);
connections_.resize(gid_connections.size());
// Copy, as we use this as the list of current target indices to write into
struct offset_t {
std::vector<cell_size_type>::iterator source;
std::vector<cell_size_type> offsets;
cell_size_type next() { return offsets[*source++]++; }
};

auto offsets = offset_t{src_domains.begin(), connection_part_};
for (const auto index: util::make_span(num_local_cells)) {
const auto tgt_gid = gids[index];
auto source_resolver = resolver(&source_resolution_map);
for (const auto cidx: util::make_span(part_connections[index], part_connections[index+1])) {
for (const auto cidx: util::make_span(part_connections[index],
part_connections[index+1])) {
const auto& conn = gid_connections[cidx];
auto src_gid = conn.source.gid;
if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(tgt_gid, src_gid);
auto src_lid = source_resolver.resolve(conn.source);
auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
auto offset = offsets[*src_domain]++;
++src_domain;
connections_[offset] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, index};
}
for (const auto cidx: util::make_span(part_ext_connections[index], part_ext_connections[index+1])) {
const auto& conn = gid_ext_connections[cidx];
auto src = global_cell_of(conn.source);
auto src_gid = conn.source.rid;
if(is_external(src_gid)) throw arb::source_gid_exceeds_limit(tgt_gid, src_gid);
auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
ext_connections_[ext] = {src, tgt_lid, conn.weight, conn.delay, index};
++ext;
auto out_idx = offsets.next();
connections_[out_idx] = {{src_gid, src_lid}, tgt_lid, conn.weight, conn.delay, (cell_size_type)index};
}
source_resolver.reset();
}
PL();

PE(init:communicator:update:index);
PE(init:communicator:update:local:index);
// Build cell partition by group for passing events to cell groups
index_part_ = util::make_partition(index_divisions_,
util::transform_view(
dom_dec.groups(),
[](const group_description& g){ return g.gids.size(); }));
util::transform_view(dec.groups(),
[](const auto& g){ return g.gids.size(); }));
PL();

PE(init:communicator:update:sort_connections);
PE(init:communicator:update:local:sort_connections);
// Sort the connections for each domain.
// This is num_domains_ independent sorts, so it can be parallelized trivially.
const auto& cp = connection_part_;
threading::parallel_for::apply(0, num_domains_, thread_pool_.get(),
threading::parallel_for::apply(0, num_domains, thread_pool_.get(),
[&](cell_size_type i) {
util::sort(util::subrange_view(connections_, cp[i], cp[i+1]));
util::sort(util::subrange_view(connections_,
connection_part_[i],
connection_part_[i+1]));
});
PL();
}

// Build lists for the _remote_ connections. No fancy acceleration structures
// are built and the list is globally sorted.
void update_remote_connections(const connectivity& rec,
const domain_decomposition& dec,
const std::vector<cell_gid_type>& gids,
size_t num_total_cells,
size_t num_local_cells,
size_t num_domains,
// Outputs; written into communicator
std::vector<connection>& ext_connections_,
// Mutable state for label resolution.
resolver& target_resolver,
resolver& source_resolver) {
PE(init:communicator:update:remote:gid_connections);
std::vector<ext_cell_connection> gid_ext_connections;
std::vector<size_t> part_ext_connections;
part_ext_connections.reserve(num_local_cells);
part_ext_connections.push_back(0);
for (const auto gid: gids) {
const auto& ext_conns = rec.external_connections_on(gid);
for (const auto& conn: ext_conns) {
// NOTE: This might look like a bug, but the _remote id_ is consider locally
// in the remote id space, ie must not be already tagged as remote.
if(is_external(conn.source.rid)) throw arb::source_gid_exceeds_limit(gid, conn.source.rid);
gid_ext_connections.emplace_back(conn);
}
part_ext_connections.push_back(gid_ext_connections.size());
}
PL();

// Construct the connections. The loop above gave us the information needed
// to do this in place.
PE(init:communicator:update:remote:connections);
ext_connections_.resize(gid_ext_connections.size());
std::size_t ext = 0;
for (const auto index: util::make_span(num_local_cells)) {
const auto tgt_gid = gids[index];
for (const auto cidx: util::make_span(part_ext_connections[index],
part_ext_connections[index+1])) {
const auto& conn = gid_ext_connections[cidx];
auto src = global_cell_of(conn.source);
auto tgt_lid = target_resolver.resolve(tgt_gid, conn.target);
ext_connections_[ext] = {src, tgt_lid, conn.weight, conn.delay, (cell_size_type) index};
++ext;
}
source_resolver.reset();
}
PL();

PE(init:communicator:update:remote:sort_connections);
std::sort(ext_connections_.begin(), ext_connections_.end());
PL();
}

void communicator::update_connections(const connectivity& rec,
const domain_decomposition& dom_dec,
const label_resolution_map& source_resolution_map,
const label_resolution_map& target_resolution_map) {
PE(init:communicator:update:clear);
// Forget all lingering information
connections_.clear();
connection_part_.clear();
ext_connections_.clear();
index_divisions_.clear();
PL();

// Remember list of local gids
PE(init:communicator:update:collect_gids);
std::vector<cell_gid_type> gids; gids.reserve(num_local_cells_);
for (const auto& g: dom_dec.groups()) util::append(gids, g.gids);
PL();

// Build resolvers
auto target_resolver = resolver(&target_resolution_map);
auto source_resolver = resolver(&source_resolution_map);

update_local_connections(rec, dom_dec, gids,
num_total_cells_, num_local_cells_, num_domains_,
connections_,
connection_part_, index_divisions_,
index_part_,
thread_pool_,
target_resolver, source_resolver);

update_remote_connections(rec, dom_dec, gids,
num_total_cells_, num_local_cells_, num_domains_,
ext_connections_,
target_resolver, source_resolver);
}

std::pair<cell_size_type, cell_size_type> communicator::group_queue_range(cell_size_type i) {
arb_assert(i<num_local_groups_);
return index_part_[i];
Expand Down
7 changes: 3 additions & 4 deletions arbor/communication/communicator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,9 @@ class ARB_ARBOR_API communicator {
/// all events that must be delivered to targets in that cell group as a
/// result of the global spike exchange, plus any events that were already
/// in the list.
void make_event_queues(
const gathered_vector<spike>& global_spikes,
std::vector<pse_vector>& queues,
const std::vector<spike>& external_spikes={});
void make_event_queues(const gathered_vector<spike>& global_spikes,
std::vector<pse_vector>& queues,
const std::vector<spike>& external_spikes={});

/// Returns the total number of global spikes over the duration of the simulation
std::uint64_t num_spikes() const;
Expand Down
Loading
Loading