[GraphBolt] Make CachePolicy hash table types generic. (#7607)
mfbalin authored Jul 28, 2024
1 parent 1202d22 commit def2a1b
Showing 2 changed files with 16 additions and 10 deletions.
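The heart of the change is in cache_policy.h below: BaseCachePolicy gains protected alias templates map_t and set_t plus a kCapacityFactor constant, and every derived policy now declares its hash tables through those aliases instead of naming phmap::flat_hash_map / phmap::flat_hash_set directly. A minimal sketch of that pattern, assuming the phmap headers are available (only map_t, set_t, and kCapacityFactor come from the diff; the class and member names here are illustrative):

```cpp
#include <cstdint>
#include <parallel_hashmap/phmap.h>

class BaseCachePolicySketch {
 protected:
  // The hash table implementation is named in exactly one place; derived
  // policies only ever refer to the aliases.
  template <typename K, typename V>
  using map_t = phmap::flat_hash_map<K, V>;
  template <typename K>
  using set_t = phmap::flat_hash_set<K>;
  static constexpr int kCapacityFactor = 2;
};

class LruLikePolicySketch : public BaseCachePolicySketch {
 public:
  explicit LruLikePolicySketch(int64_t capacity) {
    // Reserve extra headroom up front, mirroring the constructors in the diff.
    key_to_slot_.reserve(kCapacityFactor * capacity);
  }

 private:
  map_t<int64_t, int64_t> key_to_slot_;   // key -> cache slot (illustrative)
  set_t<int64_t> recently_evicted_keys_;  // ghost-set style member (illustrative)
};
```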
10 changes: 5 additions & 5 deletions graphbolt/src/cache_policy.cc
@@ -92,7 +92,7 @@ std::tuple<torch::Tensor, torch::Tensor> BaseCachePolicy::ReplaceImpl(
sizeof(CacheKey*) == sizeof(int64_t), "You need 64 bit pointers.");
auto pointers_ptr =
reinterpret_cast<CacheKey**>(pointers.data_ptr<int64_t>());
- phmap::flat_hash_set<int64_t> position_set;
+ set_t<int64_t> position_set;
position_set.reserve(keys.size(0));
for (int64_t i = 0; i < keys.size(0); i++) {
const auto key = keys_ptr[i];
@@ -132,7 +132,7 @@ S3FifoCachePolicy::S3FifoCachePolicy(int64_t capacity)
small_queue_size_target_(capacity / 10) {
TORCH_CHECK(small_queue_size_target_ > 0, "Capacity is not large enough.");
ghost_set_.reserve(ghost_queue_.Capacity());
- key_to_cache_key_.reserve(capacity);
+ key_to_cache_key_.reserve(kCapacityFactor * capacity);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
@@ -157,7 +157,7 @@ SieveCachePolicy::SieveCachePolicy(int64_t capacity)
// Ensure that queue_ is constructed first before accessing its `.end()`.
: queue_(), hand_(queue_.end()), capacity_(capacity), cache_usage_(0) {
TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
- key_to_cache_key_.reserve(capacity);
+ key_to_cache_key_.reserve(kCapacityFactor * capacity);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
@@ -181,7 +181,7 @@ void SieveCachePolicy::WritingCompleted(torch::Tensor keys) {
LruCachePolicy::LruCachePolicy(int64_t capacity)
: capacity_(capacity), cache_usage_(0) {
TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
- key_to_cache_key_.reserve(capacity);
+ key_to_cache_key_.reserve(kCapacityFactor * capacity);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
@@ -205,7 +205,7 @@ void LruCachePolicy::WritingCompleted(torch::Tensor keys) {
ClockCachePolicy::ClockCachePolicy(int64_t capacity)
: queue_(capacity), capacity_(capacity), cache_usage_(0) {
TORCH_CHECK(capacity > 0, "Capacity needs to be positive.");
- key_to_cache_key_.reserve(capacity);
+ key_to_cache_key_.reserve(kCapacityFactor * capacity);
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
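Every hunk in cache_policy.cc makes the same pair of changes: generic container names where phmap types were spelled out, and a reserve of kCapacityFactor * capacity (2x the cache capacity) instead of exactly capacity. The commit does not state why the factor was added; a plausible reading is that the extra headroom keeps the full table at a lower load factor, which for an open-addressing table such as phmap::flat_hash_map means shorter probe sequences on lookup. A small sketch of the effect, assuming phmap is installed (names other than kCapacityFactor are illustrative):

```cpp
#include <cstdint>
#include <iostream>
#include <parallel_hashmap/phmap.h>

int main() {
  constexpr int64_t capacity = 1 << 20;
  constexpr int kCapacityFactor = 2;  // constant introduced by this commit

  phmap::flat_hash_map<int64_t, int64_t> tight;
  phmap::flat_hash_map<int64_t, int64_t> roomy;
  tight.reserve(capacity);                    // old behavior
  roomy.reserve(kCapacityFactor * capacity);  // new behavior

  // Fill both maps up to the cache capacity.
  for (int64_t k = 0; k < capacity; ++k) {
    tight[k] = k;
    roomy[k] = k;
  }

  // The larger reservation leaves the full table roughly half as loaded,
  // i.e. more empty slots between occupied ones during probing.
  std::cout << "tight load factor: " << tight.load_factor() << "\n"
            << "roomy load factor: " << roomy.load_factor() << "\n";
}
```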
16 changes: 11 additions & 5 deletions graphbolt/src/cache_policy.h
@@ -155,6 +155,12 @@ class BaseCachePolicy {
virtual void WritingCompleted(torch::Tensor pointers) = 0;

protected:
+ template <typename K, typename V>
+ using map_t = phmap::flat_hash_map<K, V>;
+ template <typename K>
+ using set_t = phmap::flat_hash_set<K>;
+ static constexpr int kCapacityFactor = 2;

template <typename CachePolicy>
static std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
QueryImpl(CachePolicy& policy, torch::Tensor keys);
@@ -290,8 +296,8 @@ class S3FifoCachePolicy : public BaseCachePolicy {
int64_t capacity_;
int64_t cache_usage_;
int64_t small_queue_size_target_;
- phmap::flat_hash_set<int64_t> ghost_set_;
- phmap::flat_hash_map<int64_t, CacheKey*> key_to_cache_key_;
+ set_t<int64_t> ghost_set_;
+ map_t<int64_t, CacheKey*> key_to_cache_key_;
};

/**
@@ -383,7 +389,7 @@ class SieveCachePolicy : public BaseCachePolicy {
decltype(queue_)::iterator hand_;
int64_t capacity_;
int64_t cache_usage_;
- phmap::flat_hash_map<int64_t, CacheKey*> key_to_cache_key_;
+ map_t<int64_t, CacheKey*> key_to_cache_key_;
};

/**
@@ -485,7 +491,7 @@ class LruCachePolicy : public BaseCachePolicy {
std::list<CacheKey> queue_;
int64_t capacity_;
int64_t cache_usage_;
- phmap::flat_hash_map<int64_t, decltype(queue_)::iterator> key_to_cache_key_;
+ map_t<int64_t, decltype(queue_)::iterator> key_to_cache_key_;
};

/**
@@ -573,7 +579,7 @@ class ClockCachePolicy : public BaseCachePolicy {
CircularQueue<CacheKey> queue_;
int64_t capacity_;
int64_t cache_usage_;
- phmap::flat_hash_map<int64_t, CacheKey*> key_to_cache_key_;
+ map_t<int64_t, CacheKey*> key_to_cache_key_;
};

} // namespace storage
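Because every policy's members now go through the inherited aliases (the hunks above for S3FifoCachePolicy, SieveCachePolicy, LruCachePolicy, and ClockCachePolicy), a later switch of the hash table implementation is confined to the two alias templates in BaseCachePolicy. A purely hypothetical illustration, not part of this commit, swapping in the standard library containers:

```cpp
// Hypothetical alternative aliases; the commit itself keeps phmap.
#include <unordered_map>
#include <unordered_set>

template <typename K, typename V>
using map_t = std::unordered_map<K, V>;
template <typename K>
using set_t = std::unordered_set<K>;

// Member declarations in the derived policies, such as
//   map_t<int64_t, CacheKey*> key_to_cache_key_;
//   set_t<int64_t> ghost_set_;
// compile unchanged against these aliases.
```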
