Skip to content

Commit

Permalink
[GraphBolt] Optimize CachePolicy even further. (#7606)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Jul 28, 2024
1 parent ee20edb commit 1202d22
Show file tree
Hide file tree
Showing 6 changed files with 287 additions and 181 deletions.
77 changes: 49 additions & 28 deletions graphbolt/src/cache_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) {
keys, keys.options()
.dtype(torch::kInt64)
.pinned_memory(utils::is_pinned(keys)));
auto filtered_keys = torch::empty_like(
auto found_ptr_tensor = torch::empty_like(
keys, keys.options()
.dtype(torch::kInt64)
.pinned_memory(utils::is_pinned(keys)));
auto missing_keys = torch::empty_like(
keys, keys.options().pinned_memory(utils::is_pinned(keys)));
int64_t found_cnt = 0;
int64_t missing_cnt = keys.size(0);
Expand All @@ -44,66 +48,79 @@ BaseCachePolicy::QueryImpl(CachePolicy& policy, torch::Tensor keys) {
auto keys_ptr = keys.data_ptr<index_t>();
auto positions_ptr = positions.data_ptr<int64_t>();
auto indices_ptr = indices.data_ptr<int64_t>();
auto filtered_keys_ptr = filtered_keys.data_ptr<index_t>();
std::lock_guard lock(*policy.mtx_);
static_assert(
sizeof(CacheKey*) == sizeof(int64_t), "You need 64 bit pointers.");
auto found_ptr =
reinterpret_cast<CacheKey**>(found_ptr_tensor.data_ptr<int64_t>());
auto missing_keys_ptr = missing_keys.data_ptr<index_t>();
for (int64_t i = 0; i < keys.size(0); i++) {
const auto key = keys_ptr[i];
auto pos = policy.template Read<false>(key);
if (pos.has_value()) {
positions_ptr[found_cnt] = *pos;
filtered_keys_ptr[found_cnt] = key;
auto res = policy.template Read<false>(key);
if (res.has_value()) {
const auto [pos, cache_key_ptr] = *res;
positions_ptr[found_cnt] = pos;
found_ptr[found_cnt] = cache_key_ptr;
indices_ptr[found_cnt++] = i;
} else {
indices_ptr[--missing_cnt] = i;
filtered_keys_ptr[missing_cnt] = key;
missing_keys_ptr[missing_cnt] = key;
}
}
}));
return {
positions.slice(0, 0, found_cnt), indices,
filtered_keys.slice(0, found_cnt), filtered_keys.slice(0, 0, found_cnt)};
missing_keys.slice(0, found_cnt),
found_ptr_tensor.slice(0, 0, found_cnt)};
}

template <typename CachePolicy>
torch::Tensor BaseCachePolicy::ReplaceImpl(
std::tuple<torch::Tensor, torch::Tensor> BaseCachePolicy::ReplaceImpl(
CachePolicy& policy, torch::Tensor keys) {
auto positions = torch::empty_like(
keys, keys.options()
.dtype(torch::kInt64)
.pinned_memory(utils::is_pinned(keys)));
auto pointers = torch::empty_like(
keys, keys.options()
.dtype(torch::kInt64)
.pinned_memory(utils::is_pinned(keys)));
AT_DISPATCH_INDEX_TYPES(
keys.scalar_type(), "BaseCachePolicy::Replace", ([&] {
auto keys_ptr = keys.data_ptr<index_t>();
auto positions_ptr = positions.data_ptr<int64_t>();
static_assert(
sizeof(CacheKey*) == sizeof(int64_t), "You need 64 bit pointers.");
auto pointers_ptr =
reinterpret_cast<CacheKey**>(pointers.data_ptr<int64_t>());
phmap::flat_hash_set<int64_t> position_set;
position_set.reserve(keys.size(0));
std::lock_guard lock(*policy.mtx_);
for (int64_t i = 0; i < keys.size(0); i++) {
const auto key = keys_ptr[i];
const auto pos_optional = policy.template Read<true>(key);
const auto pos = pos_optional ? *pos_optional : policy.Insert(key);
const auto res_optional = policy.template Read<true>(key);
const auto [pos, cache_key_ptr] =
res_optional ? *res_optional : policy.Insert(key);
positions_ptr[i] = pos;
pointers_ptr[i] = cache_key_ptr;
TORCH_CHECK(
// If there are duplicate values and the key was just inserted,
// we do not have to check for the uniqueness of the positions.
pos_optional.has_value() || std::get<1>(position_set.insert(pos)),
res_optional.has_value() || std::get<1>(position_set.insert(pos)),
"Can't insert all, larger cache capacity is needed.");
}
}));
return positions;
return {positions, pointers};
}

template <bool write, typename CachePolicy>
void BaseCachePolicy::ReadingWritingCompletedImpl(
CachePolicy& policy, torch::Tensor keys) {
AT_DISPATCH_INDEX_TYPES(
keys.scalar_type(), "BaseCachePolicy::ReadingCompleted", ([&] {
auto keys_ptr = keys.data_ptr<index_t>();
std::lock_guard lock(*policy.mtx_);
for (int64_t i = 0; i < keys.size(0); i++) {
policy.template Unmark<write>(keys_ptr[i]);
}
}));
CachePolicy& policy, torch::Tensor pointers) {
static_assert(
sizeof(CacheKey*) == sizeof(int64_t), "You need 64 bit pointers.");
auto pointers_ptr =
reinterpret_cast<CacheKey**>(pointers.data_ptr<int64_t>());
for (int64_t i = 0; i < pointers.size(0); i++) {
policy.template Unmark<write>(pointers_ptr[i]);
}
}

S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity)
Expand All @@ -123,7 +140,8 @@ S3FifoCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}

torch::Tensor S3FifoCachePolicy::Replace(torch::Tensor keys) {
std::tuple<torch::Tensor, torch::Tensor> S3FifoCachePolicy::Replace(
torch::Tensor keys) {
return ReplaceImpl(*this, keys);
}

Expand All @@ -147,7 +165,8 @@ SieveCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}

torch::Tensor SieveCachePolicy::Replace(torch::Tensor keys) {
std::tuple<torch::Tensor, torch::Tensor> SieveCachePolicy::Replace(
torch::Tensor keys) {
return ReplaceImpl(*this, keys);
}

Expand All @@ -170,7 +189,8 @@ LruCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}

torch::Tensor LruCachePolicy::Replace(torch::Tensor keys) {
std::tuple<torch::Tensor, torch::Tensor> LruCachePolicy::Replace(
torch::Tensor keys) {
return ReplaceImpl(*this, keys);
}

Expand All @@ -193,7 +213,8 @@ ClockCachePolicy::Query(torch::Tensor keys) {
return QueryImpl(*this, keys);
}

torch::Tensor ClockCachePolicy::Replace(torch::Tensor keys) {
std::tuple<torch::Tensor, torch::Tensor> ClockCachePolicy::Replace(
torch::Tensor keys) {
return ReplaceImpl(*this, keys);
}

Expand Down
Loading

0 comments on commit 1202d22

Please sign in to comment.