diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 833dbc7c6c5c..74fa0c01ddd0 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -40,6 +40,7 @@ cxx_test(interpreter_test dfly_core LABELS DFLY) cxx_test(string_set_test dfly_core LABELS DFLY) cxx_test(string_map_test dfly_core LABELS DFLY) +cxx_test(oah_set_test dfly_core LABELS DFLY) cxx_test(sorted_map_test dfly_core redis_test_lib LABELS DFLY) cxx_test(bptree_set_test dfly_core LABELS DFLY) cxx_test(score_map_test dfly_core LABELS DFLY) diff --git a/src/core/oah_entry.h b/src/core/oah_entry.h new file mode 100644 index 000000000000..45e9542897d6 --- /dev/null +++ b/src/core/oah_entry.h @@ -0,0 +1,433 @@ +// Copyright 2024, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include +#include +#include + +#include "base/hash.h" +#include "base/logging.h" + +extern "C" { +#include "redis/zmalloc.h" +} + +namespace dfly { + +static uint64_t Hash(std::string_view str) { + constexpr XXH64_hash_t kHashSeed = 24061983; + return XXH3_64bits_withSeed(str.data(), str.size(), kHashSeed); +} + +static uint32_t BucketId(uint64_t hash, uint32_t capacity_log) { + assert(capacity_log > 0); + return hash >> (64 - capacity_log); +} +// doesn't possess memory, it should be created and release manually +class OAHEntry { + // we can assume that high 12 bits of user address space + // can be used for tagging. At most 52 bits of address are reserved for + // some configurations, and usually it's 48 bits. 
+ // https://docs.kernel.org/arch/arm64/memory.html + static constexpr size_t kExpiryBit = 1ULL << 52; + + // if bit is set the string length field is 1 byte instead of 4 + static constexpr size_t kSsoBit = 1ULL << 53; + static constexpr size_t kVectorBit = 1ULL << 54; + + // extended hash allows us to reduce keys comparisons + static constexpr size_t kExtHashShift = 56; + static constexpr uint32_t kExtHashSize = 8; + static constexpr size_t kExtHashMask = 0xFFULL; + static constexpr size_t kExtHashShiftedMask = kExtHashMask << kExtHashShift; + + static constexpr size_t kTagMask = 4095ULL << 52; // we reserve 12 high bits. + + public: + OAHEntry() = default; + + OAHEntry(std::string_view key, uint32_t expiry = UINT32_MAX) { + uint32_t key_size = key.size(); + + uint32_t expiry_size = (expiry != UINT32_MAX) * sizeof(expiry); + + uint32_t key_len_field_size = key_size <= std::numeric_limits::max() ? 1 : 4; + + auto size = key_len_field_size + key_size + expiry_size; + + data_ = (char*)zmalloc(size); + + auto* expiry_pos = data_; + if (expiry_size) { + SetExpiryBit(true); + std::memcpy(expiry_pos, &expiry, sizeof(expiry)); + } + + auto* key_size_pos = expiry_pos + expiry_size; + if (key_len_field_size == 1) { + SetSsoBit(); + uint8_t sso_key_size = key_size; + std::memcpy(key_size_pos, &sso_key_size, key_len_field_size); + } else { + std::memcpy(key_size_pos, &key_size, key_len_field_size); + } + + auto* key_pos = key_size_pos + key_len_field_size; + std::memcpy(key_pos, key.data(), key_size); + } + + // TODO add initializer list constructor + OAHEntry(size_t vector_size) { + // TODO rewrite to simple array + data_ = reinterpret_cast(new std::vector(vector_size)); + SetVectorBit(); + } + + OAHEntry(const OAHEntry& e) = delete; + OAHEntry(OAHEntry&& e) { + data_ = e.data_; + e.data_ = nullptr; + } + + // consider manual removing, we waste a lot of time to check nullptr + inline ~OAHEntry() { + Clear(); + } + + OAHEntry& operator=(const OAHEntry& e) = delete; + 
OAHEntry& operator=(OAHEntry&& e) { + std::swap(data_, e.data_); + return *this; + } + + inline bool Empty() const { + return !Raw(); + } + + inline operator bool() const { + return !Empty(); + } + + inline bool IsVector() const { + return (uptr() & kVectorBit) != 0; + } + + inline std::vector& AsVector() { + return *reinterpret_cast*>(Raw()); + } + + inline std::string_view Key() const { + DCHECK(!IsVector()); + return {GetKeyData(), GetKeySize()}; + } + + inline bool HasExpiry() const { + return (uptr() & kExpiryBit) != 0; + } + + // returns the expiry time of the current entry or UINT32_MAX if no expiry is set. + inline uint32_t GetExpiry() const { + std::uint32_t res = UINT32_MAX; + if (HasExpiry()) { + DCHECK(!IsVector()); + std::memcpy(&res, Raw(), sizeof(res)); + } + return res; + } + + // TODO consider another option to implement iterator + OAHEntry* operator->() { + return this; + } + + inline uint64_t GetHash() const { + return (uptr() & kExtHashShiftedMask) >> kExtHashShift; + } + + bool CheckBucketAffiliation(uint32_t bucket_id, uint32_t capacity_log, uint32_t shift_log) { + DCHECK(!IsVector()); + if (Empty()) + return false; + uint32_t bucket_id_hash_part = capacity_log > shift_log ? shift_log : capacity_log; + uint32_t bucket_mask = (1 << bucket_id_hash_part) - 1; + bucket_id &= bucket_mask; + auto stored_hash = GetHash(); + if (!stored_hash) { + stored_hash = SetHash(Hash(Key()), capacity_log, shift_log); + } + uint32_t stored_bucket_id = stored_hash >> (kExtHashSize - bucket_id_hash_part); + return bucket_id == stored_bucket_id; + } + + bool CheckExtendedHash(uint64_t hash, uint32_t capacity_log, uint32_t shift_log) { + if (Empty()) + return false; + const uint32_t start_hash_bit = capacity_log > shift_log ? 
capacity_log - shift_log : 0; + const uint32_t ext_hash_shift = 64 - start_hash_bit - kExtHashSize; + const uint64_t ext_hash = (hash >> ext_hash_shift) & kExtHashMask; + auto stored_hash = GetHash(); + if (!stored_hash && !IsVector()) { + stored_hash = SetHash(Hash(Key()), capacity_log, shift_log); + } + return stored_hash == ext_hash; + } + + // TODO rename to SetHash + // shift_log identify which bucket the element belongs to + uint64_t SetHash(uint64_t hash, uint32_t capacity_log, uint32_t shift_log) { + DCHECK(!IsVector()); + const uint32_t start_hash_bit = capacity_log > shift_log ? capacity_log - shift_log : 0; + const uint32_t ext_hash_shift = 64 - start_hash_bit - kExtHashSize; + const uint64_t result_hash = (hash >> ext_hash_shift) & kExtHashMask; + const uint64_t ext_hash = result_hash << kExtHashShift; + data_ = (char*)((uptr() & ~kExtHashShiftedMask) | ext_hash); + return result_hash; + } + + void ClearHash() { + data_ = (char*)((uptr() & ~kExtHashShiftedMask)); + } + + // return new bucket_id + uint32_t Rehash(uint32_t current_bucket_id, uint32_t prev_capacity_log, uint32_t new_capacity_log, + uint32_t shift_log) { + DCHECK(!IsVector()); + auto stored_hash = GetHash(); + + const uint32_t logs_diff = new_capacity_log - prev_capacity_log; + const uint32_t prev_significant_bits = + prev_capacity_log > shift_log ? 
shift_log : prev_capacity_log; + const uint32_t needed_hash_bits = prev_significant_bits + logs_diff; + + if (!stored_hash || needed_hash_bits > kExtHashSize) { + auto hash = Hash(Key()); + SetHash(hash, new_capacity_log, shift_log); + return BucketId(hash, new_capacity_log); + } + + const uint32_t real_bucket_end = stored_hash >> (kExtHashSize - prev_significant_bits); + const uint32_t prev_shift_mask = (1 << prev_significant_bits) - 1; + const uint32_t curr_shift = (current_bucket_id - real_bucket_end) & prev_shift_mask; + const uint32_t prev_bucket_mask = (1 << prev_capacity_log) - 1; + const uint32_t base_bucket_id = (current_bucket_id - curr_shift) & prev_bucket_mask; + + const uint32_t last_bits_mask = (1 << logs_diff) - 1; + const uint32_t stored_hash_shift = kExtHashSize - needed_hash_bits; + const uint32_t last_bits = (stored_hash >> stored_hash_shift) & last_bits_mask; + const uint32_t new_bucket_id = (base_bucket_id << logs_diff) | last_bits; + + ClearHash(); // the cache is invalid after rehash operation + + DCHECK_EQ(BucketId(Hash(Key()), new_capacity_log), new_bucket_id); + + return new_bucket_id; + } + + void SetExpiry(uint32_t at_sec) { + DCHECK(!IsVector()); + if (HasExpiry()) { + auto* expiry_pos = Raw(); + std::memcpy(expiry_pos, &at_sec, sizeof(at_sec)); + } else { + *this = OAHEntry(Key(), at_sec); + } + } + + // TODO refactor, because it's inefficient + std::optional Find(std::string_view str, uint64_t hash, uint32_t capacity_log, + uint32_t shift_log, uint32_t* set_size, uint32_t time_now = 0) { + if (Empty()) + return std::nullopt; + if (!IsVector()) { + ExpireIfNeeded(time_now, set_size); + return CheckExtendedHash(hash, capacity_log, shift_log) && Key() == str + ? 
0 + : std::optional(); + } + auto& vec = AsVector(); + for (size_t i = 0, size = vec.size(); i < size; ++i) { + vec[i].ExpireIfNeeded(time_now, set_size); + if (vec[i].CheckExtendedHash(hash, capacity_log, shift_log) && vec[i].Key() == str) { + return i; + } + } + return std::nullopt; + } + + void ExpireIfNeeded(uint32_t time_now, uint32_t* set_size) { + DCHECK(!IsVector()); + if (GetExpiry() <= time_now) { + Clear(); + --*set_size; + } + } + + // TODO refactor, because it's inefficient + inline uint32_t Insert(OAHEntry&& e) { + if (Empty()) { + *this = std::move(e); + return 0; + } else if (!IsVector()) { + OAHEntry tmp(2); + auto& arr = tmp.AsVector(); + arr[0] = std::move(*this); + arr[1] = std::move(e); + *this = std::move(tmp); + return 1; + } else { + auto& arr = AsVector(); + size_t i = 0; + for (; i < arr.size(); ++i) { + if (!arr[i]) { + arr[i] = std::move(e); + } + } + arr.push_back(std::move(e)); + return arr.size() - 1; + } + } + + uint32_t ElementsNum() { + if (Empty()) { + return 0; + } else if (!IsVector()) { + return 1; + } + return AsVector().size(); + } + + // TODO remove, it is inefficient + inline OAHEntry& operator[](uint32_t pos) { + DCHECK(!Empty()); + if (!IsVector()) { + DCHECK(pos == 0); + return *this; + } else { + auto& arr = AsVector(); + DCHECK(pos < arr.size()); + return arr[pos]; + } + } + + OAHEntry Remove(uint32_t pos) { + if (Empty()) { + // I'm not sure that this scenario should be check at all + DCHECK(pos == 0); + return OAHEntry(); + } else if (!IsVector()) { + DCHECK(pos == 0); + return std::move(*this); + } else { + auto& arr = AsVector(); + DCHECK(pos < arr.size()); + return std::move(arr[pos]); + } + } + + OAHEntry Pop() { + if (IsVector()) { + auto& arr = AsVector(); + for (auto& e : arr) { + if (e) + return std::move(e); + } + } + return std::move(*this); + } + + template >* = nullptr> + bool Scan(const T& cb, uint32_t bucket_id, uint32_t capacity_log, uint32_t shift_log) { + if (!IsVector()) { + if 
(CheckBucketAffiliation(bucket_id, capacity_log, shift_log)) { + cb(Key()); + return true; + } + } else { + auto& arr = AsVector(); + bool result = false; + for (auto& el : arr) { + if (el.CheckBucketAffiliation(bucket_id, capacity_log, shift_log)) { + cb(el.Key()); + result = true; + } + } + return result; + } + return false; + } + + protected: + inline void Clear() { + // TODO add optimization to avoid destructor calls during vector allocator + if (!data_) + return; + + if (IsVector()) { + delete &AsVector(); + } else { + zfree(Raw()); + } + data_ = nullptr; + } + + const char* GetKeyData() const { + uint32_t key_field_size = HasSso() ? 1 : 4; + return Raw() + GetExpirySize() + key_field_size; + } + + uint32_t GetKeySize() const { + if (HasSso()) { + uint8_t size = 0; + std::memcpy(&size, Raw() + GetExpirySize(), sizeof(size)); + return size; + } + uint32_t size = 0; + std::memcpy(&size, Raw() + GetExpirySize(), sizeof(size)); + return size; + } + + inline uint64_t uptr() const { + return uint64_t(data_); + } + + inline char* Raw() const { + return (char*)(uptr() & ~kTagMask); + } + + inline void SetExpiryBit(bool b) { + if (b) + data_ = (char*)(uptr() | kExpiryBit); + else + data_ = (char*)(uptr() & (~kExpiryBit)); + } + + inline void SetVectorBit() { + data_ = (char*)(uptr() | kVectorBit); + } + + inline void SetSsoBit() { + data_ = (char*)(uptr() | kSsoBit); + } + + inline bool HasSso() const { + return (uptr() & kSsoBit) != 0; + } + + inline size_t Size() { + size_t key_field_size = HasSso() ? 1 : 4; + size_t expiry_field_size = HasExpiry() ? 4 : 0; + return expiry_field_size + key_field_size + GetKeySize(); + } + + inline std::uint32_t GetExpirySize() const { + return HasExpiry() ? 
sizeof(std::uint32_t) : 0; + } + + // memory daya layout [Expiry, key_size, key] + char* data_ = nullptr; +}; + +} // namespace dfly diff --git a/src/core/oah_set.h b/src/core/oah_set.h new file mode 100644 index 000000000000..d0c5a31b1930 --- /dev/null +++ b/src/core/oah_set.h @@ -0,0 +1,376 @@ +// Copyright 2024, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#pragma once + +#include +#include + +#include + +#include "base/pmr/memory_resource.h" +#include "oah_entry.h" + +namespace dfly { + +class OAHSet { // Open Addressing Hash Set + using Buckets = std::vector>; + + public: + class iterator { + public: + using iterator_category = std::forward_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = OAHEntry; + using pointer = OAHEntry*; + using reference = OAHEntry&; + + iterator(OAHSet* owner, uint32_t bucket_id, uint32_t pos_in_bucket) + : owner_(owner), bucket_(bucket_id), pos_(pos_in_bucket) { + // TODO rewrite, it's inefficient + SetEntryIt(); + } + + void SetExpiryTime(uint32_t ttl_sec, size_t* obj_malloc_used) { + owner_->entries_[bucket_][pos_].SetExpiry(owner_->EntryTTL(ttl_sec)); + } + + iterator& operator++() { + ++pos_; + SetEntryIt(); + return *this; + } + + bool operator==(const iterator& r) const { + if (owner_ == nullptr || r.owner_ == nullptr) { + return owner_ == r.owner_; + } + DCHECK(owner_ == r.owner_); + return bucket_ == r.bucket_ && pos_ == r.pos_; + } + + bool operator!=(const iterator& r) const { + return !operator==(r); + } + + reference operator*() { + return owner_->entries_[bucket_][pos_]; + } + + reference operator->() { + return owner_->entries_[bucket_][pos_]; + } + + bool HasExpiry() { + return owner_->entries_[bucket_][pos_].HasExpiry(); + } + + uint32_t ExpiryTime() { + return owner_->entries_[bucket_][pos_].GetExpiry(); + } + + operator bool() const { + return owner_; + } + + private: + // find valid entry_ iterator starting from buckets_it_ and set it + void 
SetEntryIt() { + if (!owner_) + return; + for (auto size = owner_->entries_.size(); bucket_ < size; ++bucket_) { + auto& bucket = owner_->entries_[bucket_]; + for (uint32_t bucket_size = bucket.ElementsNum(); pos_ < bucket_size; ++pos_) { + if (bucket[pos_]) + return; + } + pos_ = 0; + } + owner_ = nullptr; + } + + private: + OAHSet* owner_; + uint32_t bucket_; + uint32_t pos_; + }; + + iterator begin() { + return iterator(this, 0, 0); + } + + iterator end() { + return iterator(nullptr, 0, 0); + } + + explicit OAHSet(PMR_NS::memory_resource* mr = PMR_NS::get_default_resource()) : entries_(mr) { + } + + static constexpr uint32_t kMaxBatchLen = 32; + + bool Add(std::string_view str, uint32_t ttl_sec = UINT32_MAX) { + if (entries_.empty() || size_ >= entries_.size()) { + Reserve(Capacity() * 2); + } + uint64_t hash = Hash(str); + const auto bucket_id = BucketId(hash, capacity_log_); + + // TODO FindInternal and FindEmptyAround can be one function to get better performance + if (auto item = FindInternal(bucket_id, str, hash); item != end()) { + return false; + } + + uint32_t bucket = FindEmptyAround(bucket_id); + + DCHECK(bucket_id + kDisplacementSize > bucket); + + AddUnique(str, bucket, hash, ttl_sec); + return true; + } + + void Reserve(size_t sz) { + sz = absl::bit_ceil(sz); + if (sz > entries_.size()) { + auto prev_capacity_log = capacity_log_; + capacity_log_ = std::max(kMinCapacityLog, uint32_t(absl::bit_width(sz) - 1)); + entries_.resize(Capacity()); + Rehash(prev_capacity_log); + } + } + + void Clear() { + capacity_log_ = 0; + entries_.resize(0); + size_ = 0; + } + + iterator AddUnique(std::string_view str, uint32_t bucket, uint64_t hash, + uint32_t ttl_sec = UINT32_MAX) { + ++size_; + uint32_t at = EntryTTL(ttl_sec); + uint32_t pos = entries_[bucket].Insert(OAHEntry(str, at)); + entries_[bucket][pos].SetHash(hash, capacity_log_, kShiftLog); + return iterator(this, bucket, pos); + } + + unsigned AddMany(absl::Span span, uint32_t ttl_sec = UINT32_MAX) { + 
Reserve(span.size()); + unsigned res = 0; + for (auto& s : span) { + if (Add(s, ttl_sec) != end()) { + res++; + } + } + return res; + } + + // TODO: Consider using chunks for this as in StringSet + void Fill(OAHSet* other) { + DCHECK(other->entries_.empty()); + other->Reserve(UpperBoundSize()); + other->set_time(time_now()); + for (auto it = begin(), it_end = end(); it != it_end; ++it) { + other->Add(it->Key(), it.HasExpiry() ? it.ExpiryTime() - time_now() : UINT32_MAX); + } + } + + /** + * stable scanning api. has the same guarantees as redis scan command. + * we avoid doing bit-reverse by using a different function to derive a bucket id + * from hash values. By using msb part of hash we make it "stable" with respect to + * rehashes. For example, with table log size 4 (size 16), entries in bucket id + * 1110 come from hashes 1110XXXXX.... When a table grows to log size 5, + * these entries can move either to 11100 or 11101. So if we traversed with our cursor + * range [0000-1110], it's guaranteed that in grown table we do not need to cover again + * [00000-11100]. Similarly with shrinkage, if a table is shrunk to log size 3, + * keys from 1110 and 1111 will move to bucket 111. Again, it's guaranteed that we + * covered the range [000-111] (all keys in that case). + * Returns: next cursor or 0 if reached the end of scan. + * cursor = 0 - initiates a new scan. + */ + + using ItemCb = std::function; + + uint32_t Scan(uint32_t cursor, const ItemCb& cb) { + const uint32_t capacity_mask = Capacity() - 1; + uint32_t bucket_id = cursor >> (32 - capacity_log_); + const uint32_t displacement_size = std::min(kDisplacementSize, BucketCount()); + + // First find the bucket to scan, skip empty buckets. 
+ for (; bucket_id < entries_.size(); ++bucket_id) { + bool res = false; + for (uint32_t i = 0; i < displacement_size; i++) { + const uint32_t shifted_bid = (bucket_id + i) & capacity_mask; + res |= entries_[shifted_bid].Scan(cb, bucket_id, capacity_log_, kShiftLog); + } + if (res) + break; + } + + if (++bucket_id >= entries_.size()) { + return 0; + } + + return bucket_id << (32 - capacity_log_); + } + + OAHEntry Pop() { + for (auto& bucket : entries_) { + if (auto res = bucket.Pop(); !res.Empty()) { + --size_; + return res; + } + } + return {}; + } + + bool Erase(std::string_view str) { + if (entries_.empty()) + return false; + + uint64_t hash = Hash(str); + auto bucket_id = BucketId(hash, capacity_log_); + auto item = FindInternal(bucket_id, str, hash); + if (item != end()) { + *item = OAHEntry(); + return true; + } + return false; + } + + iterator Find(std::string_view member) { + if (entries_.empty()) + return end(); + + uint64_t hash = Hash(member); + auto bucket_id = BucketId(hash, capacity_log_); + auto res = FindInternal(bucket_id, member, hash); + return res; + } + + bool Contains(std::string_view member) { + return Find(member) != end(); + } + + // Returns the number of elements in the map. Note that it might be that some of these elements + // have expired and can't be accessed. + size_t UpperBoundSize() const { + return size_; + } + + bool Empty() const { + return size_ == 0; + } + + std::uint32_t BucketCount() const { + return entries_.size(); // the same as Capacity() + } + + std::uint32_t Capacity() const { + return 1 << capacity_log_; + } + + // set an abstract time that allows expiry. 
+ void set_time(uint32_t val) { + time_now_ = val; + } + + uint32_t time_now() const { + return time_now_; + } + + size_t ObjMallocUsed() const { + // TODO implement + LOG(FATAL) << "ExpirationUsed() isn't implemented"; + return 0; + } + + size_t SetMallocUsed() const { + // TODO implement + LOG(FATAL) << "ExpirationUsed() isn't implemented"; + return 0; + } + + bool ExpirationUsed() const { + // TODO + LOG(FATAL) << "ExpirationUsed() isn't implemented"; + return true; + } + + size_t SizeSlow() { + // TODO + LOG(FATAL) << "SizeSlow() isn't implemented"; + // CollectExpired(); + return size_; + } + + private: + // was Grow in StringSet + void Rehash(uint32_t prev_capacity_log) { + auto prev_size = 1 << prev_capacity_log; + for (int64_t bucket_id = prev_size - 1; bucket_id >= 0; --bucket_id) { + auto bucket = std::move(entries_[bucket_id]); + // TODO add optimization for package processing + for (uint32_t pos = 0, size = bucket.ElementsNum(); pos < size; ++pos) { + // TODO operator [] is inefficient and it is better to avoid it + if (bucket[pos]) { + auto new_bucket_id = + bucket[pos].Rehash(bucket_id, prev_capacity_log, capacity_log_, kShiftLog); + + // TODO add optimization for package processing + new_bucket_id = FindEmptyAround(new_bucket_id); + + // insert method is inefficient + entries_[new_bucket_id]->Insert(std::move(bucket[pos])); + } + } + } + } + + uint32_t EntryTTL(uint32_t ttl_sec) const { + return ttl_sec == UINT32_MAX ? 
ttl_sec : time_now_ + ttl_sec; + } + + uint32_t FindEmptyAround(uint32_t bid) { + const uint32_t displacement_size = std::min(kDisplacementSize, BucketCount()); + const uint32_t capacity_mask = Capacity() - 1; + for (uint32_t i = 0; i < displacement_size; i++) { + const uint32_t bucket_id = (bid + i) & capacity_mask; + if (entries_[bucket_id].Empty()) + return bucket_id; + // TODO add expiration logic + } + + DCHECK(Capacity() > kDisplacementSize); + uint32_t extension_point_shift = displacement_size - 1; + bid |= extension_point_shift; + DCHECK(bid < Capacity()); + return bid; + } + + // return bucket_id and position otherwise max + iterator FindInternal(uint32_t bid, std::string_view str, uint64_t hash) { + const uint32_t displacement_size = std::min(kDisplacementSize, BucketCount()); + const uint32_t capacity_mask = Capacity() - 1; + for (uint32_t i = 0; i < displacement_size; i++) { + const uint32_t bucket_id = (bid + i) & capacity_mask; + auto pos = entries_[bucket_id].Find(str, hash, capacity_log_, kShiftLog, &size_, time_now_); + if (pos) { + return iterator{this, bucket_id, *pos}; + } + } + return end(); + } + + private: + static constexpr std::uint32_t kMinCapacityLog = 3; // TODO make template + static constexpr std::uint32_t kShiftLog = 4; // TODO make template + static constexpr std::uint32_t kDisplacementSize = (1 << kShiftLog); // TODO check + std::uint32_t capacity_log_ = 0; + std::uint32_t size_ = 0; // number of elements in the set. + std::uint32_t time_now_ = 0; + Buckets entries_; +}; + +} // namespace dfly diff --git a/src/core/oah_set_test.cc b/src/core/oah_set_test.cc new file mode 100644 index 000000000000..e5c0489b8765 --- /dev/null +++ b/src/core/oah_set_test.cc @@ -0,0 +1,855 @@ +// Copyright 2022, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. 
+// + +#include "core/oah_set.h" + +#include +#include +#include + +#include +#include +#include + +#include "base/gtest.h" +#include "core/mi_memory_resource.h" +#include "glog/logging.h" + +extern "C" { +#include "redis/zmalloc.h" +} + +namespace dfly { + +using namespace std; + +class ISSAllocator : public PMR_NS::memory_resource { + public: + bool all_freed() const { + return alloced_ == 0; + } + + void* do_allocate(size_t bytes, size_t alignment) override { + alloced_ += bytes; + void* p = PMR_NS::new_delete_resource()->allocate(bytes, alignment); + return p; + } + + void do_deallocate(void* p, size_t bytes, size_t alignment) override { + alloced_ -= bytes; + return PMR_NS::new_delete_resource()->deallocate(p, bytes, alignment); + } + + bool do_is_equal(const PMR_NS::memory_resource& other) const noexcept override { + return PMR_NS::new_delete_resource()->is_equal(other); + } + + private: + size_t alloced_ = 0; +}; + +class OAHSetTest : public ::testing::Test { + protected: + static void SetUpTestSuite() { + auto* tlh = mi_heap_get_backing(); + init_zmalloc_threadlocal(tlh); + } + + static void TearDownTestSuite() { + } + + void SetUp() override { + ss_ = new OAHSet(&alloc_); + generator_.seed(0); + } + + void TearDown() override { + delete ss_; + + // ensure there are no memory leaks after every test + EXPECT_TRUE(alloc_.all_freed()); + EXPECT_EQ(zmalloc_used_memory_tl, 0); + } + + OAHSet* ss_; + ISSAllocator alloc_; + mt19937 generator_; +}; + +static string random_string(mt19937& rand, unsigned len) { + const string_view alpanum = "1234567890abcdefghijklmnopqrstuvwxyz"; + string ret; + ret.reserve(len); + + for (size_t i = 0; i < len; ++i) { + ret += alpanum[rand() % alpanum.size()]; + } + + return ret; +} + +TEST_F(OAHSetTest, OAHEntryTest) { + OAHEntry test("0123456789", 2); + + EXPECT_EQ(test.Key(), "0123456789"sv); + EXPECT_EQ(test.GetExpiry(), 2); + + OAHEntry first("123456789"); + first.SetHash(Hash(first.Key()), 4, 4); + + 
test.Insert(std::move(first)); + + uint32_t set_size = 4; + EXPECT_EQ(test.Find("123456789", Hash("123456789"), 4, 4, &set_size), 1); + + test.Insert(OAHEntry("23456789")); + + EXPECT_TRUE(test.Remove(0)); + EXPECT_FALSE(test.Remove(0)); + + EXPECT_EQ(test.Remove(2).Key(), "23456789"); + EXPECT_EQ(test.Pop().Key(), "123456789"); +} + +TEST_F(OAHSetTest, HashCheckTest) { + OAHEntry isl; + { + uint32_t pos = isl.Insert(OAHEntry("0123456789")); + isl[pos].SetHash(Hash(isl[pos].Key()), 3, 4); + EXPECT_TRUE(isl[pos].CheckBucketAffiliation(4, 3, 4)); + EXPECT_FALSE(isl[pos].CheckBucketAffiliation(6, 3, 4)); + EXPECT_TRUE(isl[pos].CheckBucketAffiliation(4, 4, 3)); + EXPECT_FALSE(isl[pos].CheckBucketAffiliation(6, 4, 3)); + } + { + uint32_t pos = isl.Insert(OAHEntry("123456789")); + isl[pos].SetHash(Hash(isl[pos].Key()), 3, 4); + } + { + uint32_t pos = isl.Insert(OAHEntry("23456789")); + isl[pos].SetHash(Hash(isl[pos].Key()), 3, 4); + } + { + uint32_t pos = isl.Insert(OAHEntry("3456789")); + isl[pos].SetHash(Hash(isl[pos].Key()), 3, 4); + } + { + uint32_t pos = isl.Insert(OAHEntry("456789")); + isl[pos].SetHash(Hash(isl[pos].Key()), 3, 4); + } + + uint32_t num_expired_fields = 0; + + EXPECT_TRUE(isl.Find("0123456789", Hash("0123456789"), 3, 4, &num_expired_fields)); + EXPECT_TRUE(isl.Find("123456789", Hash("123456789"), 3, 4, &num_expired_fields)); + EXPECT_TRUE(isl.Find("23456789", Hash("23456789"), 3, 4, &num_expired_fields)); + EXPECT_TRUE(isl.Find("3456789", Hash("3456789"), 3, 4, &num_expired_fields)); + EXPECT_TRUE(isl.Find("456789", Hash("456789"), 3, 4, &num_expired_fields)); + + auto idx = isl.Find("456789", Hash("456789"), 3, 4, &num_expired_fields); + auto new_pos = isl[*idx].Rehash(7, 3, 4, 4); + EXPECT_EQ(new_pos, 6); + EXPECT_FALSE(isl[*idx].GetHash()); +} + +TEST_F(OAHSetTest, OAHSetAddFindTest) { + OAHSet ss; + std::set test_set; + + for (int i = 0; i < 10000; ++i) { + test_set.insert(base::RandStr(20)); + } + + for (const auto& s : test_set) { + 
EXPECT_TRUE(ss.Add(s)); + } + + for (const auto& s : test_set) { + auto e = ss.Find(s); + EXPECT_EQ(e->Key(), s); + } + + EXPECT_EQ(ss.Capacity(), 16384); +} + +TEST_F(OAHSetTest, Basic) { + EXPECT_TRUE(ss_->Add("foo"sv)); + EXPECT_TRUE(ss_->Add("bar"sv)); + uint32_t size = ss_->UpperBoundSize(); + EXPECT_FALSE(ss_->Add("foo"sv)); + EXPECT_FALSE(ss_->Add("bar"sv)); + EXPECT_EQ(ss_->UpperBoundSize(), size); + EXPECT_TRUE(ss_->Contains("foo"sv)); + EXPECT_TRUE(ss_->Contains("bar"sv)); + EXPECT_EQ(2, ss_->UpperBoundSize()); +} + +TEST_F(OAHSetTest, StandardAddErase) { + EXPECT_TRUE(ss_->Add("@@@@@@@@@@@@@@@@") != ss_->end()); + EXPECT_TRUE(ss_->Add("A@@@@@@@@@@@@@@@") != ss_->end()); + EXPECT_TRUE(ss_->Add("AA@@@@@@@@@@@@@@") != ss_->end()); + EXPECT_TRUE(ss_->Add("AAA@@@@@@@@@@@@@") != ss_->end()); + EXPECT_TRUE(ss_->Add("AAAAAAAAA@@@@@@@") != ss_->end()); + EXPECT_TRUE(ss_->Add("AAAAAAAAAA@@@@@@") != ss_->end()); + EXPECT_TRUE(ss_->Add("AAAAAAAAAAAAAAA@") != ss_->end()); + EXPECT_TRUE(ss_->Add("AAAAAAAAAAAAAAAA") != ss_->end()); + EXPECT_TRUE(ss_->Add("AAAAAAAAAAAAAAAD") != ss_->end()); + EXPECT_TRUE(ss_->Add("BBBBBAAAAAAAAAAA") != ss_->end()); + EXPECT_TRUE(ss_->Add("BBBBBBBBAAAAAAAA") != ss_->end()); + EXPECT_TRUE(ss_->Add("CCCCCBBBBBBBBBBB") != ss_->end()); + + // Remove link in the middle of chain + EXPECT_TRUE(ss_->Erase("BBBBBBBBAAAAAAAA")); + // Remove start of a chain + EXPECT_TRUE(ss_->Erase("CCCCCBBBBBBBBBBB")); + // Remove end of link + EXPECT_TRUE(ss_->Erase("AAA@@@@@@@@@@@@@")); + // Remove only item in chain + EXPECT_TRUE(ss_->Erase("AA@@@@@@@@@@@@@@")); + EXPECT_TRUE(ss_->Erase("AAAAAAAAA@@@@@@@")); + EXPECT_TRUE(ss_->Erase("AAAAAAAAAA@@@@@@")); + EXPECT_TRUE(ss_->Erase("AAAAAAAAAAAAAAA@")); +} + +TEST_F(OAHSetTest, DisplacedBug) { + string_view vals[] = {"imY", "OVl", "NhH", "BCe", "YDL", "lpb", + "nhF", "xod", "zYR", "PSa", "hce", "cTR"}; + ss_->AddMany(absl::MakeSpan(vals), UINT32_MAX); + + ss_->Add("fIc"); + ss_->Erase("YDL"); + ss_->Add("fYs"); + 
ss_->Erase("hce"); + ss_->Erase("nhF"); + ss_->Add("dye"); + ss_->Add("xZT"); + ss_->Add("LVK"); + ss_->Erase("zYR"); + ss_->Erase("fYs"); + ss_->Add("ueB"); + ss_->Erase("PSa"); + ss_->Erase("OVl"); + ss_->Add("cga"); + ss_->Add("too"); + ss_->Erase("ueB"); + ss_->Add("HZe"); + ss_->Add("oQn"); + ss_->Erase("too"); + ss_->Erase("HZe"); + ss_->Erase("xZT"); + ss_->Erase("cga"); + ss_->Erase("cTR"); + ss_->Erase("BCe"); + ss_->Add("eua"); + ss_->Erase("lpb"); + ss_->Add("OXK"); + ss_->Add("QmO"); + ss_->Add("SzV"); + ss_->Erase("QmO"); + ss_->Add("jbe"); + ss_->Add("BPN"); + ss_->Add("OfH"); + ss_->Add("Muf"); + ss_->Add("CwP"); + ss_->Erase("Muf"); + ss_->Erase("xod"); + ss_->Add("Cis"); + ss_->Add("Xvd"); + ss_->Erase("SzV"); + ss_->Erase("eua"); + ss_->Add("DGb"); + ss_->Add("leD"); + ss_->Add("MVX"); + ss_->Add("HPq"); +} + +TEST_F(OAHSetTest, Resizing) { + constexpr size_t num_strs = 4096; + unordered_set strs; + while (strs.size() != num_strs) { + auto str = random_string(generator_, 10); + strs.insert(str); + } + + unsigned size = 0; + for (auto it = strs.begin(); it != strs.end(); ++it) { + const auto& str = *it; + EXPECT_TRUE(ss_->Add(str, 1)); + EXPECT_EQ(ss_->UpperBoundSize(), size + 1); + + // make sure we haven't lost any items after a grow + // which happens every power of 2 + if ((size & (size - 1)) == 0) { + for (auto j = strs.begin(); j != it; ++j) { + const auto& str = *j; + auto it = ss_->Find(str); + ASSERT_NE(it, ss_->end()); + EXPECT_TRUE(it.HasExpiry()); + EXPECT_EQ(it.ExpiryTime(), ss_->time_now() + 1); + } + } + ++size; + } +} + +TEST_F(OAHSetTest, SimpleScan) { + unordered_set info = {"foo", "bar"}; + unordered_set seen; + + for (auto str : info) { + EXPECT_TRUE(ss_->Add(str)); + } + + uint32_t cursor = 0; + do { + cursor = ss_->Scan(cursor, [&](std::string_view str) { + EXPECT_TRUE(info.count(str)); + seen.insert(str); + }); + } while (cursor != 0); + + EXPECT_EQ(seen.size(), info.size()); + EXPECT_TRUE(equal(seen.begin(), seen.end(), 
info.begin())); +} + +// // Ensure REDIS scan guarantees are met +TEST_F(OAHSetTest, ScanGuarantees) { + unordered_set to_be_seen = {"foo", "bar"}; + unordered_set not_be_seen = {"AAA", "BBB"}; + unordered_set maybe_seen = {"AA@@@@@@@@@@@@@@", "AAA@@@@@@@@@@@@@", + "AAAAAAAAA@@@@@@@", "AAAAAAAAAA@@@@@@"}; + unordered_set seen; + + auto scan_callback = [&](std::string_view str) { + EXPECT_TRUE(to_be_seen.count(str) || maybe_seen.count(str)); + EXPECT_FALSE(not_be_seen.count(str)); + if (to_be_seen.count(str)) { + seen.insert(str); + } + }; + + EXPECT_EQ(ss_->Scan(0, scan_callback), 0); + + for (auto str : not_be_seen) { + EXPECT_TRUE(ss_->Add(str)); + } + + for (auto str : not_be_seen) { + EXPECT_TRUE(ss_->Erase(str)); + } + + for (auto str : to_be_seen) { + EXPECT_TRUE(ss_->Add(str)); + } + + // should reach at least the first item in the set + uint32_t cursor = ss_->Scan(0, scan_callback); + + for (auto str : maybe_seen) { + EXPECT_TRUE(ss_->Add(str)); + } + + while (cursor != 0) { + cursor = ss_->Scan(cursor, scan_callback); + } + + EXPECT_TRUE(seen.size() == to_be_seen.size()); +} + +TEST_F(OAHSetTest, IntOnly) { + constexpr size_t num_ints = 8192; + unordered_set numbers; + for (size_t i = 0; i < num_ints; ++i) { + numbers.insert(i); + EXPECT_TRUE(ss_->Add(to_string(i))); + } + EXPECT_EQ(ss_->UpperBoundSize(), num_ints); + + for (size_t i = 0; i < num_ints; ++i) { + ASSERT_FALSE(ss_->Add(to_string(i))); + } + EXPECT_EQ(ss_->UpperBoundSize(), num_ints); + + size_t num_remove = generator_() % 4096; + unordered_set removed; + + for (size_t i = 0; i < num_remove; ++i) { + auto remove_int = generator_() % num_ints; + auto remove = to_string(remove_int); + if (numbers.count(remove_int)) { + ASSERT_TRUE(ss_->Contains(remove)) << remove_int; + EXPECT_TRUE(ss_->Erase(remove)); + numbers.erase(remove_int); + } else { + EXPECT_FALSE(ss_->Erase(remove)); + } + + EXPECT_FALSE(ss_->Contains(remove)); + removed.insert(remove); + } + + size_t expected_seen = 0; + auto 
scan_callback = [&](std::string_view str_v) { + std::string str(str_v); + EXPECT_FALSE(removed.count(str)); + + if (numbers.count(std::atoi(str.data()))) { + ++expected_seen; + } + }; + + uint32_t cursor = 0; + do { + cursor = ss_->Scan(cursor, scan_callback); + // randomly throw in some new numbers + uint32_t val = generator_(); + ss_->Add(to_string(val)); + } while (cursor != 0); + + EXPECT_GE(expected_seen + removed.size(), num_ints); +} + +TEST_F(OAHSetTest, XtremeScanGrow) { + unordered_set to_see, force_grow, seen; + + while (to_see.size() != 8) { + to_see.insert(random_string(generator_, 10)); + } + + while (force_grow.size() != 8192) { + string str = random_string(generator_, 10); + + if (to_see.count(str)) { + continue; + } + + force_grow.insert(random_string(generator_, 10)); + } + + for (auto& str : to_see) { + EXPECT_TRUE(ss_->Add(str)); + } + + auto scan_callback = [&](string_view strv) { + std::string str(strv); + if (to_see.count(str)) { + seen.insert(str); + } + }; + + uint32_t cursor = ss_->Scan(0, scan_callback); + + // force approx 10 grows + for (auto& s : force_grow) { + EXPECT_TRUE(ss_->Add(s)); + } + + while (cursor != 0) { + cursor = ss_->Scan(cursor, scan_callback); + } + + EXPECT_EQ(seen.size(), to_see.size()); +} + +TEST_F(OAHSetTest, Pop) { + constexpr size_t num_items = 8; + unordered_set to_insert; + + while (to_insert.size() != num_items) { + auto str = random_string(generator_, 10); + if (to_insert.count(str)) { + continue; + } + + to_insert.insert(str); + EXPECT_TRUE(ss_->Add(str)); + } + + while (!ss_->Empty()) { + size_t size = ss_->UpperBoundSize(); + auto str = ss_->Pop(); + DCHECK(ss_->UpperBoundSize() == to_insert.size() - 1); + DCHECK(str); + DCHECK(to_insert.count(std::string(str.Key()))); + DCHECK_EQ(ss_->UpperBoundSize(), size - 1); + to_insert.erase(std::string(str.Key())); + } + + DCHECK(ss_->Empty()); + DCHECK(to_insert.empty()); +} + +TEST_F(OAHSetTest, Iteration) { + ss_->Add("foo"); + for (const auto& ptr : *ss_) { + 
LOG(INFO) << ptr;
  }
  ss_->Clear();
  constexpr size_t num_items = 8192;
  unordered_set to_insert;

  while (to_insert.size() != num_items) {
    auto str = random_string(generator_, 10);
    if (to_insert.count(str)) {
      continue;
    }

    to_insert.insert(str);
    EXPECT_TRUE(ss_->Add(str));
  }

  // Each iterated key must be a member of to_insert; erasing as we go also
  // proves no key is visited twice.
  for (const auto& ptr : *ss_) {
    std::string str(ptr.Key());
    EXPECT_TRUE(to_insert.count(str));
    to_insert.erase(str);
  }

  EXPECT_EQ(to_insert.size(), 0);
}

// SetExpiryTime on an entry that already has an expiry must overwrite it.
TEST_F(OAHSetTest, SetFieldExpireHasExpiry) {
  EXPECT_TRUE(ss_->Add("k1", 100));
  auto k = ss_->Find("k1");
  EXPECT_TRUE(k.HasExpiry());
  EXPECT_EQ(k.ExpiryTime(), 100);
  size_t obj_malloc_used;  // out-param for the reallocation bookkeeping
  k.SetExpiryTime(1, &obj_malloc_used);
  EXPECT_TRUE(k.HasExpiry());
  EXPECT_EQ(k.ExpiryTime(), 1);
}

// SetExpiryTime on an entry without an expiry must attach one.
TEST_F(OAHSetTest, SetFieldExpireNoHasExpiry) {
  EXPECT_TRUE(ss_->Add("k1"));
  auto k = ss_->Find("k1");
  EXPECT_FALSE(k.HasExpiry());
  size_t obj_malloc_used;  // out-param for the reallocation bookkeeping
  k.SetExpiryTime(10, &obj_malloc_used);
  EXPECT_TRUE(k.HasExpiry());
  EXPECT_EQ(k.ExpiryTime(), 10);
}

// Exercises TTL semantics: re-adding an expired key, absolute expiry times
// relative to set_time(), and lazy cleanup of expired entries via Scan.
TEST_F(OAHSetTest, Ttl) {
  EXPECT_TRUE(ss_->Add("bla"sv, 1));
  EXPECT_FALSE(ss_->Add("bla"sv, 1));
  auto it = ss_->Find("bla"sv);
  EXPECT_EQ(1u, it.ExpiryTime());

  // At time 1 the entry has expired, so the same key can be added again.
  ss_->set_time(1);
  EXPECT_TRUE(ss_->Add("bla"sv, 1));
  EXPECT_EQ(1u, ss_->UpperBoundSize());

  for (unsigned i = 0; i < 100; ++i) {
    EXPECT_TRUE(ss_->Add(absl::StrCat("foo", i), 1));
  }
  EXPECT_EQ(101u, ss_->UpperBoundSize());
  it = ss_->Find("foo50");
  EXPECT_EQ("foo50"sv, it->Key());
  // Added at time 1 with TTL 1 => absolute expiry time 2.
  EXPECT_EQ(2u, it.ExpiryTime());

  ss_->set_time(2);
  // Cleanup all `foo` entries
  uint32_t cursor = 0;
  do {
    cursor = ss_->Scan(cursor, [&](std::string_view) {});
  } while (cursor != 0);

  for (unsigned i = 0; i < 100; ++i) {
    EXPECT_TRUE(ss_->Add(absl::StrCat("bar", i)));
  }
  EXPECT_EQ(100u, ss_->UpperBoundSize());
  it = ss_->Find("bar50");
  EXPECT_FALSE(it.HasExpiry());

  // Only the TTL-less "bar" keys should remain after the cleanup scan.
  for (auto it = ss_->begin(); it != ss_->end(); ++it) {
ASSERT_TRUE(absl::StartsWith(it->Key(), "bar")) << it->Key();
    string str(it->Key());
    VLOG(1) << *it;
  }
}

// Repeatedly interleaves random Reserve() calls with Add() to stress the
// grow path, then Clear()s and starts over.
TEST_F(OAHSetTest, Grow) {
  for (size_t j = 0; j < 10; ++j) {
    for (size_t i = 0; i < 4098; ++i) {
      ss_->Reserve(generator_() % 256);
      auto str = random_string(generator_, 3);
      ss_->Add(str);
    }
    ss_->Clear();
  }
}

// Reserve() on a non-empty set must never lose existing elements.
TEST_F(OAHSetTest, Reserve) {
  vector strs;

  for (size_t i = 0; i < 10; ++i) {
    strs.push_back(random_string(generator_, 10));
    ss_->Add(strs.back());
  }

  for (size_t j = 2; j < 20; j += 3) {
    ss_->Reserve(j * 20);
    for (size_t i = 0; i < 10; ++i) {
      ASSERT_TRUE(ss_->Contains(strs[i]));
    }
  }
}

// Fill() must produce an exact copy: same size, every key present.
TEST_F(OAHSetTest, Fill) {
  for (size_t i = 0; i < 100; ++i) {
    ss_->Add(random_string(generator_, 10));
  }
  OAHSet s2;
  ss_->Fill(&s2);
  EXPECT_EQ(s2.UpperBoundSize(), ss_->UpperBoundSize());
  for (const auto& s : *ss_) {
    EXPECT_TRUE(s2.Contains(s.Key()));
  }
}

// Iterating an empty set must execute the loop body zero times.
TEST_F(OAHSetTest, IterateEmpty) {
  for (const auto& s : *ss_) {
    // We're iterating to make sure there is no crash.
However, if we got here, it's a bug + CHECK(false) << "Found entry " << s << " in empty set"; + } +} + +// size_t memUsed(OAHSet& obj) { +// return obj.ObjMallocUsed() + obj.SetMallocUsed(); +// } + +void BM_Clone(benchmark::State& state) { + vector strs; + mt19937 generator(0); + OAHSet ss1, ss2; + unsigned elems = state.range(0); + for (size_t i = 0; i < elems; ++i) { + string str = random_string(generator, 10); + ss1.Add(str); + } + ss2.Reserve(ss1.UpperBoundSize()); + while (state.KeepRunning()) { + for (auto& src : ss1) { + ss2.Add(src.Key()); + } + state.PauseTiming(); + ss2.Clear(); + ss2.Reserve(ss1.UpperBoundSize()); + state.ResumeTiming(); + } +} +BENCHMARK(BM_Clone)->ArgName("elements")->Arg(32000); + +void BM_Fill(benchmark::State& state) { + unsigned elems = state.range(0); + vector strs; + mt19937 generator(0); + OAHSet ss1, ss2; + for (size_t i = 0; i < elems; ++i) { + string str = random_string(generator, 10); + ss1.Add(str); + } + + while (state.KeepRunning()) { + ss1.Fill(&ss2); + state.PauseTiming(); + ss2.Clear(); + state.ResumeTiming(); + } +} +BENCHMARK(BM_Fill)->ArgName("elements")->Arg(32000); + +void BM_Clear(benchmark::State& state) { + unsigned elems = state.range(0); + mt19937 generator(0); + OAHSet ss; + while (state.KeepRunning()) { + state.PauseTiming(); + for (size_t i = 0; i < elems; ++i) { + string str = random_string(generator, 16); + ss.Add(str); + } + state.ResumeTiming(); + ss.Clear(); + } +} +BENCHMARK(BM_Clear)->ArgName("elements")->Arg(32000); + +void BM_Add(benchmark::State& state) { + vector strs; + mt19937 generator(0); + OAHSet ss; + unsigned elems = state.range(0); + unsigned keySize = state.range(1); + for (size_t i = 0; i < elems; ++i) { + string str = random_string(generator, keySize); + strs.push_back(str); + } + ss.Reserve(elems); + while (state.KeepRunning()) { + for (auto& str : strs) + ss.Add(str); + state.PauseTiming(); + // state.counters["Memory_Used"] = memUsed(ss); + ss.Clear(); + ss.Reserve(elems); + 
state.ResumeTiming(); + } +} +BENCHMARK(BM_Add) + ->ArgNames({"elements", "Key Size"}) + ->ArgsProduct({{1000, 10000, 100000}, {10, 100, 1000}}); + +void BM_AddMany(benchmark::State& state) { + vector strs; + mt19937 generator(0); + OAHSet ss; + unsigned elems = state.range(0); + unsigned keySize = state.range(1); + for (size_t i = 0; i < elems; ++i) { + string str = random_string(generator, keySize); + strs.push_back(str); + } + ss.Reserve(elems); + vector svs; + for (const auto& str : strs) { + svs.push_back(str); + } + while (state.KeepRunning()) { + ss.AddMany(absl::MakeSpan(svs)); + state.PauseTiming(); + CHECK_EQ(ss.UpperBoundSize(), elems); + // state.counters["Memory_Used"] = memUsed(ss); + ss.Clear(); + ss.Reserve(elems); + state.ResumeTiming(); + } +} +BENCHMARK(BM_AddMany) + ->ArgNames({"elements", "Key Size"}) + ->ArgsProduct({{1000, 10000, 100000}, {10, 100, 1000}}); + +void BM_Erase(benchmark::State& state) { + std::vector strs; + mt19937 generator(0); + OAHSet ss; + auto elems = state.range(0); + auto keySize = state.range(1); + for (long int i = 0; i < elems; ++i) { + std::string str = random_string(generator, keySize); + strs.push_back(str); + ss.Add(str); + } + // state.counters["Memory_Before_Erase"] = memUsed(ss); + while (state.KeepRunning()) { + for (auto& str : strs) { + ss.Erase(str); + } + state.PauseTiming(); + // state.counters["Memory_After_Erase"] = memUsed(ss); + for (auto& str : strs) { + ss.Add(str); + } + state.ResumeTiming(); + } +} +BENCHMARK(BM_Erase) + ->ArgNames({"elements", "Key Size"}) + ->ArgsProduct({{1000, 10000, 100000}, {10, 100, 1000}}); + +void BM_Get(benchmark::State& state) { + std::vector strs; + mt19937 generator(0); + OAHSet ss; + auto elems = state.range(0); + auto keySize = state.range(1); + for (long int i = 0; i < elems; ++i) { + std::string str = random_string(generator, keySize); + strs.push_back(str); + ss.Add(str); + } + while (state.KeepRunning()) { + for (auto& str : strs) { + ss.Find(str); + } + } +} 
BENCHMARK(BM_Get)
    ->ArgNames({"elements", "Key Size"})
    ->ArgsProduct({{1000, 10000, 100000}, {10, 100, 1000}});

// Measures the cost of a single table grow: fills a copy to exactly its
// capacity, then times Add() calls until the capacity increases.
void BM_Grow(benchmark::State& state) {
  vector strs;
  mt19937 generator(0);
  OAHSet src;
  unsigned elems = 1 << 18;
  for (size_t i = 0; i < elems; ++i) {
    src.Add(random_string(generator, 16), UINT32_MAX);
    strs.push_back(random_string(generator, 16));
  }

  while (state.KeepRunning()) {
    state.PauseTiming();
    OAHSet tmp;
    src.Fill(&tmp);
    CHECK_EQ(tmp.Capacity(), elems);
    state.ResumeTiming();
    for (const auto& str : strs) {
      tmp.Add(str);
      if (tmp.Capacity() > elems) {
        break;  // we grew
      }
    }

    CHECK_GT(tmp.Capacity(), elems);
  }
}
BENCHMARK(BM_Grow);

// unsigned total_wasted_memory = 0;

// TEST_F(OAHSetTest, ReallocIfNeeded) {
//   auto build_str = [](size_t i) { return to_string(i) + string(131, 'a'); };

//   auto count_waste = [](const mi_heap_t* heap, const mi_heap_area_t* area, void* block,
//                         size_t block_size, void* arg) {
//     size_t used = block_size * area->used;
//     total_wasted_memory += area->committed - used;
//     return true;
//   };

//   for (size_t i = 0; i < 10'000; i++)
//     ss_->Add(build_str(i));

//   for (size_t i = 0; i < 10'000; i++) {
//     if (i % 10 == 0)
//       continue;
//     ss_->Erase(build_str(i));
//   }

//   mi_heap_collect(mi_heap_get_backing(), true);
//   mi_heap_visit_blocks(mi_heap_get_backing(), false, count_waste, nullptr);
//   size_t wasted_before = total_wasted_memory;

//   size_t underutilized = 0;
//   for (auto it = ss_->begin(); it != ss_->end(); ++it) {
//     underutilized += zmalloc_page_is_underutilized(*it, 0.9);
//     it.ReallocIfNeeded(0.9);
//   }
//   // Check there are underutilized pages
//   CHECK_GT(underutilized, 0u);

//   total_wasted_memory = 0;
//   mi_heap_collect(mi_heap_get_backing(), true);
//   mi_heap_visit_blocks(mi_heap_get_backing(), false, count_waste, nullptr);
//   size_t wasted_after = total_wasted_memory;

//   // Check we waste significantly less now
//   EXPECT_GT(wasted_before, wasted_after * 2);

//   EXPECT_EQ(ss_->UpperBoundSize(), 1000);
//   for (size_t i = 0; i < 1000; i++)
//     EXPECT_EQ(*ss_->Find(build_str(i * 10)), build_str(i * 10));
// }

}  // namespace dfly