From 8f7a855c14cc2659fa82d605d4741248da681790 Mon Sep 17 00:00:00 2001 From: Emiel Por Date: Thu, 12 Dec 2024 17:14:49 -0800 Subject: [PATCH] Add maximum key length check, safe string compare, and byte packing. --- benchmarks/hash_map.cpp | 17 ++++++++++++++--- catkit_core/HashMap.h | 30 +++++++++++++++++++++--------- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/benchmarks/hash_map.cpp b/benchmarks/hash_map.cpp index 82477c58..d4a29c30 100644 --- a/benchmarks/hash_map.cpp +++ b/benchmarks/hash_map.cpp @@ -5,9 +5,10 @@ int main(int argc, char **argv) { - typedef HashMap MyHashMap; + typedef HashMap MyHashMap; std::size_t buffer_size = MyHashMap::CalculateBufferSize(); + std::cout << "Buffer size: " << buffer_size << " bytes" << std::endl; char *buffer = new char[buffer_size]; MyHashMap map(buffer); @@ -21,9 +22,14 @@ int main(int argc, char **argv) std::string key = "key" + std::to_string(i); auto start = GetTimeStamp(); - map.Insert(key, i); + bool success = map.Insert(key, uint16_t(i)); auto end = GetTimeStamp(); + if (!success) + { + std::cout << "Insertion failed." << std::endl; + } + total_time += end - start; } @@ -36,9 +42,14 @@ int main(int argc, char **argv) std::string key = "key" + std::to_string(i); auto start = GetTimeStamp(); - const int *value = map.Find(key); + auto *value = map.Find(key); auto end = GetTimeStamp(); + if (value == nullptr || *value != i) + { + std::cout << "Key not found." << std::endl; + } + total_time += end - start; } diff --git a/catkit_core/HashMap.h b/catkit_core/HashMap.h index cabf698e..4e065375 100644 --- a/catkit_core/HashMap.h +++ b/catkit_core/HashMap.h @@ -6,6 +6,7 @@ #include #include #include +#include // MurmurHash3 32-bit version uint32_t murmurhash3(const std::string &key, uint32_t seed = 0) @@ -69,7 +70,7 @@ template class HashMap { private: - enum EntryFlags + enum EntryFlags : uint8_t { UNOCCUPIED = 0, INITIALIZING = 1, @@ -78,10 +79,10 @@ class HashMap struct Entry { + Value value; + std::atomic flags = EntryFlags::UNOCCUPIED; char key[MaxKeyLength]; - - Value value; }; Entry *m_Data; @@ -114,6 +115,12 @@ class HashMap bool Insert(const std::string &key, const Value &value) { + if (key.size() > MaxKeyLength) + { + // Key is too long to fit in the fixed-size buffer. + return false; + } + size_t index = hash(key); for (size_t i = 0; i < Size; ++i) @@ -139,7 +146,7 @@ class HashMap if (flags == EntryFlags::OCCUPIED) { // Check if the key is our key. - if (std::strcmp(m_Data[probe].key, key.c_str()) == 0) + if (AreKeysTheSame(m_Data[probe].key, key.c_str())) { // Key already exists. return false; @@ -148,10 +155,9 @@ class HashMap } else { - // Copy key ensuring null-termination. - std::size_t key_length = std::min(key.size(), MaxKeyLength - 1); + // Copy key. + std::size_t key_length = std::min(key.size(), MaxKeyLength); key.copy(m_Data[probe].key, key_length); - m_Data[probe].key[MaxKeyLength - 1] = '\0'; // Copy m_Data. m_Data[probe].value = value; @@ -169,7 +175,7 @@ class HashMap const Value *Find(const std::string &key) const { - if (key.size() >= MaxKeyLength) + if (key.size() > MaxKeyLength) { // Key is too long to fit in the fixed-size buffer. return nullptr; @@ -183,13 +189,14 @@ class HashMap EntryFlags flags = m_Data[probe].flags.load(std::memory_order_acquire); - if (flags == EntryFlags::OCCUPIED && std::strcmp(m_Data[probe].key, key.c_str()) == 0) + if (flags == EntryFlags::OCCUPIED && AreKeysTheSame(m_Data[probe].key, key.c_str())) { return &m_Data[probe].value; } if (flags != EntryFlags::OCCUPIED) { + // Key not found. break; } } @@ -197,6 +204,11 @@ class HashMap // Key not found. return nullptr; } + + bool AreKeysTheSame(const char *ky1, const char *ky2) const + { + return std::strncmp(ky1, ky2, MaxKeyLength) == 0; + } }; #endif // HASH_MAP_H