
Commit

use batch idx and node id as unique key for dedup in temporal sampling (#267)

Co-authored-by: Dong Wang <[email protected]>
yaoyaowd and Dong Wang authored Aug 11, 2022
1 parent 916ba55 commit 6143af2
Showing 1 changed file with 73 additions and 22 deletions.
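The heart of the change: deduplication during temporal sampling is now keyed on the pair (node id, batch idx) rather than the node id alone, so a node reached from two different seed nodes is materialized twice and each seed keeps a disjoint computation tree. A minimal standalone sketch of the difference, using std::map in place of phmap and hypothetical values:

    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <utility>

    int main() {
      // Old scheme: key by node id only. Node 7 reached from seeds 0 and 1
      // collapses onto a single sampled copy, merging the two trees.
      std::map<int64_t, int64_t> by_node;
      by_node.insert({7, 0});
      by_node.insert({7, 1}); // no effect: key 7 already present

      // New scheme: key by (node id, batch idx). Each seed keeps its own copy.
      std::map<std::pair<int64_t, int64_t>, int64_t> by_node_and_batch;
      by_node_and_batch.insert({{7, 0}, 0});
      by_node_and_batch.insert({{7, 1}, 1}); // distinct key: kept

      std::cout << by_node.size() << " vs " << by_node_and_batch.size() << "\n";
      // prints "1 vs 2"
    }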
95 changes: 73 additions & 22 deletions csrc/cpu/neighbor_sample_cpu.cpp
@@ -10,6 +10,8 @@ using namespace std;

namespace {

typedef phmap::flat_hash_map<pair<int64_t, int64_t>, int64_t> temporarl_edge_dict;

template <bool replace, bool directed>
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample(const torch::Tensor &colptr, const torch::Tensor &row,
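The new temporarl_edge_dict typedef maps a (node id, batch idx) pair to a local node index; the commit relies on phmap accepting pair keys directly. As a portable sketch of the same shape with the standard library only, std::unordered_map needs an explicit hash for pair keys (the names below are illustrative, not from the commit):

    #include <cstddef>
    #include <cstdint>
    #include <unordered_map>
    #include <utility>

    // std::hash has no std::pair specialization, so supply one.
    struct PairHash {
      std::size_t operator()(const std::pair<int64_t, int64_t> &p) const {
        // Mix the two halves; good enough for a sketch.
        return std::hash<int64_t>()(p.first) * 0x9e3779b97f4a7c15ULL ^
               std::hash<int64_t>()(p.second);
      }
    };

    using temporal_edge_dict =
        std::unordered_map<std::pair<int64_t, int64_t>, int64_t, PairHash>;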
@@ -146,11 +148,15 @@ hetero_sample(const vector<node_t> &node_types,

// Initialize some data structures for the sampling process:
phmap::flat_hash_map<node_t, vector<int64_t>> samples_dict;
phmap::flat_hash_map<node_t, vector<pair<int64_t, int64_t>>> temp_samples_dict;
phmap::flat_hash_map<node_t, phmap::flat_hash_map<int64_t, int64_t>> to_local_node_dict;
phmap::flat_hash_map<node_t, temporarl_edge_dict> temp_to_local_node_dict;
phmap::flat_hash_map<node_t, vector<int64_t>> root_time_dict;
for (const auto &node_type : node_types) {
samples_dict[node_type];
temp_samples_dict[node_type];
to_local_node_dict[node_type];
temp_to_local_node_dict[node_type];
root_time_dict[node_type];
}

@@ -175,20 +181,33 @@
}

auto &samples = samples_dict.at(node_type);
auto &temp_samples = temp_samples_dict.at(node_type);
auto &to_local_node = to_local_node_dict.at(node_type);
auto &temp_to_local_node = temp_to_local_node_dict.at(node_type);
auto &root_time = root_time_dict.at(node_type);
for (int64_t i = 0; i < input_node.numel(); i++) {
const auto &v = input_node_data[i];
samples.push_back(v);
to_local_node.insert({v, i});
if (temporal) {
temp_samples.push_back({v, i});
temp_to_local_node.insert({{v, i}, i});
} else {
samples.push_back(v);
to_local_node.insert({v, i});
}
if (temporal)
root_time.push_back(node_time_data[v]);
}
}

phmap::flat_hash_map<node_t, pair<int64_t, int64_t>> slice_dict;
for (const auto &kv : samples_dict)
slice_dict[kv.first] = {0, kv.second.size()};
if (temporal) {
for (const auto &kv : temp_samples_dict) {
slice_dict[kv.first] = {0, kv.second.size()};
}
} else {
for (const auto &kv : samples_dict)
slice_dict[kv.first] = {0, kv.second.size()};
}

vector<rel_t> all_rel_types;
for (const auto &kv : num_neighbors_dict) {
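slice_dict tracks, per node type, the half-open range [begin, end) of nodes added by the previous hop, so each hop only expands the current frontier while the sample vectors keep growing. A self-contained sketch of that bookkeeping on a made-up graph where every frontier node yields one new neighbor:

    #include <cstdint>
    #include <cstdio>
    #include <utility>
    #include <vector>

    int main() {
      std::vector<int64_t> samples = {0, 1};      // seed nodes
      std::pair<int64_t, int64_t> slice = {0, 2}; // frontier as [begin, end)

      for (int hop = 0; hop < 2; hop++) {
        for (int64_t i = slice.first; i < slice.second; i++) {
          // Pretend every frontier node yields exactly one new neighbor.
          samples.push_back(samples[i] + 100);
        }
        // Advance: the next frontier is exactly what this hop appended.
        slice = {slice.second, (int64_t)samples.size()};
      }
      std::printf("%zu nodes sampled\n", samples.size()); // 6
    }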
@@ -203,8 +222,11 @@
const auto &dst_node_type = get<2>(edge_type);
const auto num_samples = num_neighbors_dict.at(rel_type)[ell];
const auto &dst_samples = samples_dict.at(dst_node_type);
const auto &temp_dst_samples = temp_samples_dict.at(dst_node_type);
auto &src_samples = samples_dict.at(src_node_type);
auto &temp_src_samples = temp_samples_dict.at(src_node_type);
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
auto &temp_to_local_src_node = temp_to_local_node_dict.at(src_node_type);

const torch::Tensor &colptr = colptr_dict.at(rel_type);
const auto *colptr_data = colptr.data_ptr<int64_t>();
@@ -223,7 +245,8 @@
const auto &begin = slice_dict.at(dst_node_type).first;
const auto &end = slice_dict.at(dst_node_type).second;
for (int64_t i = begin; i < end; i++) {
const auto &w = dst_samples[i];
const auto &w = temporal ? temp_dst_samples[i].first : dst_samples[i];
const int64_t root_w = temporal ? temp_dst_samples[i].second : -1;
int64_t dst_time = 0;
if (temporal)
dst_time = dst_root_time[i];
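satisfy_time itself is not shown in this diff; dst_time is the seed's timestamp propagated through src_root_time, and candidate neighbors that fail the check are skipped. A plausible stand-in for such a predicate, assuming the rule is "a neighbor may not be newer than the root" (the actual rule lives in satisfy_time):

    #include <cstdint>
    #include <unordered_map>

    // Hypothetical stand-in for the sampler's time check: keep neighbor v
    // only if it has a timestamp no newer than the root's time.
    bool satisfies_time(const std::unordered_map<int64_t, int64_t> &node_time,
                        int64_t root_time, int64_t v) {
      auto it = node_time.find(v);
      return it != node_time.end() && it->second <= root_time;
    }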
@@ -241,15 +264,18 @@
if (temporal) {
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
continue;
// force disjoint of computation tree
// force disjoint of computation tree based on source batch idx.
// note that the sampling always needs to have directed=True
// for temporal case
// to_local_src_node is not used for temporal / directed case
const int64_t sample_idx = src_samples.size();
src_samples.push_back(v);
src_root_time.push_back(dst_time);
const auto res = temp_to_local_src_node.insert({{v, root_w}, (int64_t)temp_src_samples.size()});
if (res.second) {
temp_src_samples.push_back({v, root_w});
src_root_time.push_back(dst_time);
}

cols.push_back(i);
rows.push_back(sample_idx);
rows.push_back(res.first->second);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
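The dedup mechanics lean on hash-map insert returning an {iterator, inserted} pair: a single lookup either assigns a fresh local index or hands back the existing one, and src_root_time is appended only when the node is genuinely new, keeping it in lockstep with temp_src_samples. The same pattern in isolation, with std::map standing in for phmap:

    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <utility>
    #include <vector>

    int main() {
      std::map<std::pair<int64_t, int64_t>, int64_t> temp_to_local;
      std::vector<std::pair<int64_t, int64_t>> temp_samples;
      std::vector<int64_t> root_time;

      auto visit = [&](int64_t v, int64_t root_w, int64_t dst_time) {
        auto res =
            temp_to_local.insert({{v, root_w}, (int64_t)temp_samples.size()});
        if (res.second) {                // newly inserted: materialize node
          temp_samples.push_back({v, root_w});
          root_time.push_back(dst_time); // stays in lockstep with temp_samples
        }
        return res.first->second;        // row index to record for the edge
      };

      std::cout << visit(5, 0, 10) << " " << visit(5, 0, 10) << "\n"; // 0 0
    }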
@@ -272,14 +298,17 @@
// TODO Infinity loop if no neighbor satisfies time constraint:
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
continue;
// force disjoint of computation tree
// force disjoint of computation tree based on source batch idx.
// note that the sampling always needs to have directed=True
// for temporal case
const int64_t sample_idx = src_samples.size();
src_samples.push_back(v);
src_root_time.push_back(dst_time);
const auto res = temp_to_local_src_node.insert({{v, root_w}, (int64_t)temp_src_samples.size()});
if (res.second) {
temp_src_samples.push_back({v, root_w});
src_root_time.push_back(dst_time);
}

cols.push_back(i);
rows.push_back(sample_idx);
rows.push_back(res.first->second);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
@@ -307,14 +336,17 @@
if (temporal) {
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
continue;
// force disjoint of computation tree
// force disjoint of computation tree based on source batch idx.
// note that the sampling always needs to have directed=True
// for temporal case
const int64_t sample_idx = src_samples.size();
src_samples.push_back(v);
src_root_time.push_back(dst_time);
const auto res = temp_to_local_src_node.insert({{v, root_w}, (int64_t)temp_src_samples.size()});
if (res.second) {
temp_src_samples.push_back({v, root_w});
src_root_time.push_back(dst_time);
}

cols.push_back(i);
rows.push_back(sample_idx);
rows.push_back(res.first->second);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
@@ -331,11 +363,18 @@
}
}

for (const auto &kv : samples_dict) {
slice_dict[kv.first] = {slice_dict.at(kv.first).second, kv.second.size()};
if (temporal) {
for (const auto &kv : temp_samples_dict) {
slice_dict[kv.first] = {slice_dict.at(kv.first).second, kv.second.size()};
}
} else {
for (const auto &kv : samples_dict)
slice_dict[kv.first] = {slice_dict.at(kv.first).second, kv.second.size()};
}
}

// Temporal sampling is only supported with directed = true:
assert(!(temporal && !directed));
if (!directed) { // Construct the subgraph among the sampled nodes:
phmap::flat_hash_map<int64_t, int64_t>::iterator iter;
for (const auto &kv : colptr_dict) {
@@ -371,6 +410,18 @@
}
}

// Construct samples dictionary from temporal sample dictionary.
if (temporal) {
for (const auto &kv : temp_samples_dict) {
const auto &node_type = kv.first;
const auto &samples = kv.second;
samples_dict[node_type].reserve(samples.size());
for (const auto &v : samples) {
samples_dict[node_type].push_back(v.first);
}
}
}

return make_tuple(from_vector<node_t, int64_t>(samples_dict),
from_vector<rel_t, int64_t>(rows_dict),
from_vector<rel_t, int64_t>(cols_dict),
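Because the rest of the pipeline expects plain node-id lists, the final hunk projects each temporal (node id, batch idx) pair back to its node id before returning. The projection is just taking .first of every pair, for instance:

    #include <algorithm>
    #include <cstdint>
    #include <iterator>
    #include <utility>
    #include <vector>

    int main() {
      std::vector<std::pair<int64_t, int64_t>> temp_samples = {
          {7, 0}, {7, 1}, {3, 0}};
      std::vector<int64_t> samples;
      samples.reserve(temp_samples.size());
      std::transform(temp_samples.begin(), temp_samples.end(),
                     std::back_inserter(samples),
                     [](const auto &p) { return p.first; });
      // samples == {7, 7, 3}: node 7 appears once per seed that reached it.
    }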
