Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 262 additions & 16 deletions merge_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ func NewMergeIterator(db *DB, iterators []*iterator, ts int64, ascending bool) (
db: db,
}

// Copy iterators for potential re-initialization
copy(mi.allIterators, iterators)

// Initialize each iterator and add to appropriate heap based on direction
Expand Down Expand Up @@ -110,7 +109,7 @@ func (mi *MergeIterator) initializeIterator(it *iterator) error {
} else {
// If peek failed or timestamp is not visible, try to find a valid position
for {
key, value, ts, ok := t.Prev()
key, value, ts, ok = t.Prev()
if !ok {
it.exhausted = true
break
Expand Down Expand Up @@ -161,7 +160,7 @@ func (mi *MergeIterator) initializeIterator(it *iterator) error {
} else {
// If peek failed or timestamp is not visible, try to find a valid position
for {
key, value, ts, ok := t.Prev()
key, value, ts, ok = t.Prev()
if !ok {
it.exhausted = true
break
Expand Down Expand Up @@ -213,7 +212,7 @@ func (mi *MergeIterator) initializeIterator(it *iterator) error {
} else {
// If peek failed or timestamp is not visible, try to find a valid position
for {
key, value, ts, ok := t.Prev()
key, value, ts, ok = t.Prev()
if !ok {
it.exhausted = true
break
Expand Down Expand Up @@ -281,7 +280,7 @@ func (mi *MergeIterator) initializeIterator(it *iterator) error {

// Find a valid position going backwards
for t.Prev() {
entry, err := mi.extractKLogEntry(t.Value())
entry, err = mi.extractKLogEntry(t.Value())
if err != nil {
if it.sst != nil {
mi.db.log(fmt.Sprintf("Potential block corruption detected for SSTable %d at Level %d: %v", it.sst.Id, it.sst.Level, err))
Expand Down Expand Up @@ -354,19 +353,91 @@ func (mi *MergeIterator) extractKLogEntry(value interface{}) (*KLogEntry, error)
// SetDirection changes the iteration direction
func (mi *MergeIterator) SetDirection(ascending bool) error {
if mi.ascending == ascending {
return nil // No change needed
return nil
}

mi.ascending = ascending

// Clear current heaps
mi.heap = make(iteratorHeap, 0, len(mi.allIterators))
mi.reverseHeap = make(reverseIteratorHeap, 0, len(mi.allIterators))

// Re-initialize all iterators for the new direction
for _, it := range mi.allIterators {
it.ascending = ascending
it.exhausted = false // Reset exhausted state

if !it.exhausted && len(it.currentKey) > 0 {
// Skip nil iterators
if it.underlyingIterator == nil {
it.exhausted = true
continue
}

func() {
defer func() {
if r := recover(); r != nil {
it.exhausted = true
}
}()

switch t := it.underlyingIterator.(type) {
case *skiplist.Iterator:
if t == nil {
it.exhausted = true
return
}
if ascending {
t.ToFirst()
} else {
t.ToLast()
}
case *skiplist.RangeIterator:
if t == nil {
it.exhausted = true
return
}
if ascending {
t.ToFirst()
} else {
t.ToLast()
}
case *skiplist.PrefixIterator:
if t == nil {
it.exhausted = true
return
}
if ascending {
t.ToFirst()
} else {
t.ToLast()
}
case *tree.Iterator:
if t == nil {
it.exhausted = true
return
}
var err error
if ascending {
err = t.SeekToFirst()
} else {
err = t.SeekToLast()
}
if err != nil {
it.exhausted = true
return
}
default:
it.exhausted = true
return
}
}()

// Re-initialize the iterator with its first valid entry
if err := mi.initializeIteratorAfterPositioning(it); err != nil {
return err
}

// Add to appropriate heap if not exhausted
if !it.exhausted {
if ascending {
heap.Push(&mi.heap, it)
} else {
Expand All @@ -378,12 +449,191 @@ func (mi *MergeIterator) SetDirection(ascending bool) error {
return nil
}

// initializeIteratorAfterPositioning sets up the iterator after it has been positioned
// This is different from initializeIterator because the iterator is already positioned
func (mi *MergeIterator) initializeIteratorAfterPositioning(it *iterator) error {
if it.sst != nil {
atomic.CompareAndSwapInt32(&it.sst.isBeingRead, 0, 1)
}

switch t := it.underlyingIterator.(type) {
case *skiplist.Iterator:
if t == nil {
it.exhausted = true
return nil
}

if it.ascending {
key, value, ts, ok := t.Next()
if ok && ts <= mi.ts {
it.currentKey = key
it.currentValue = value
it.currentTimestamp = ts
} else {
it.exhausted = true
}
} else {
key, value, ts, ok := t.Peek()
if ok && ts <= mi.ts {
it.currentKey = key
it.currentValue = value
it.currentTimestamp = ts
} else {
key, value, ts, ok = t.Prev()
if ok && ts <= mi.ts {
it.currentKey = key
it.currentValue = value
it.currentTimestamp = ts
} else {
it.exhausted = true
}
}
}

case *skiplist.RangeIterator:
if t == nil {
it.exhausted = true
return nil
}

if it.ascending {
key, value, ts, ok := t.Next()
if ok && ts <= mi.ts {
it.currentKey = key
it.currentValue = value
it.currentTimestamp = ts
} else {
it.exhausted = true
}
} else {
key, value, ts, ok := t.Peek()
if ok && ts <= mi.ts {
it.currentKey = key
it.currentValue = value
it.currentTimestamp = ts
} else {
key, value, ts, ok = t.Prev()
if ok && ts <= mi.ts {
it.currentKey = key
it.currentValue = value
it.currentTimestamp = ts
} else {
it.exhausted = true
}
}
}

case *skiplist.PrefixIterator:
if t == nil {
it.exhausted = true
return nil
}

if it.ascending {
key, value, ts, ok := t.Next()
if ok && ts <= mi.ts {
it.currentKey = key
it.currentValue = value
it.currentTimestamp = ts
} else {
it.exhausted = true
}
} else {
key, value, ts, ok := t.Peek()
if ok && ts <= mi.ts {
it.currentKey = key
it.currentValue = value
it.currentTimestamp = ts
} else {
key, value, ts, ok = t.Prev()
if ok && ts <= mi.ts {
it.currentKey = key
it.currentValue = value
it.currentTimestamp = ts
} else {
it.exhausted = true
}
}
}

case *tree.Iterator:
if t == nil {
it.exhausted = true
return nil
}

if t.Valid() {
entry, err := mi.extractKLogEntry(t.Value())
if err != nil {
if it.sst != nil {
mi.db.log(fmt.Sprintf("Potential block corruption detected for SSTable %d at Level %d: %v", it.sst.Id, it.sst.Level, err))
}
it.exhausted = true
return err
}

if entry.Timestamp <= mi.ts {
it.currentKey = entry.Key
it.currentValue = it.sst.readValueFromVLog(entry.ValueBlockID)
it.currentTimestamp = entry.Timestamp
} else {
found := false
if it.ascending {
for t.Next() {
entry, err = mi.extractKLogEntry(t.Value())
if err != nil {
if it.sst != nil {
mi.db.log(fmt.Sprintf("Potential block corruption detected for SSTable %d at Level %d: %v", it.sst.Id, it.sst.Level, err))
}
it.exhausted = true
return err
}
if entry.Timestamp <= mi.ts {
it.currentKey = entry.Key
it.currentValue = it.sst.readValueFromVLog(entry.ValueBlockID)
it.currentTimestamp = entry.Timestamp
found = true
break
}
}
} else {
for t.Prev() {
entry, err = mi.extractKLogEntry(t.Value())
if err != nil {
if it.sst != nil {
mi.db.log(fmt.Sprintf("Potential block corruption detected for SSTable %d at Level %d: %v", it.sst.Id, it.sst.Level, err))
}
it.exhausted = true
return err
}
if entry.Timestamp <= mi.ts {
it.currentKey = entry.Key
it.currentValue = it.sst.readValueFromVLog(entry.ValueBlockID)
it.currentTimestamp = entry.Timestamp
found = true
break
}
}
}
if !found {
it.exhausted = true
}
}
} else {
it.exhausted = true
}

default:
it.exhausted = true
}

return nil
}

// seekIterator positions an iterator at or near the given key
func (mi *MergeIterator) seekIterator(it *iterator, seekKey []byte) error {
switch t := it.underlyingIterator.(type) {
case *skiplist.Iterator:
// SkipList iterators don't have a direct seek method
// We need to iterate until we find the right position
it.exhausted = true
for {
var key []byte
Expand Down Expand Up @@ -416,7 +666,6 @@ func (mi *MergeIterator) seekIterator(it *iterator, seekKey []byte) error {
case *tree.Iterator:
if err := t.Seek(seekKey); err != nil {
it.exhausted = true
// Release the isBeingRead flag on error
if it.sst != nil {
atomic.CompareAndSwapInt32(&it.sst.isBeingRead, 1, 0)
}
Expand All @@ -429,7 +678,6 @@ func (mi *MergeIterator) seekIterator(it *iterator, seekKey []byte) error {
if err != nil {
if it.sst != nil {
mi.db.log(fmt.Sprintf("Potential block corruption detected for SSTable %d at Level %d: %v", it.sst.Id, it.sst.Level, err))
// Release the isBeingRead flag on error
atomic.CompareAndSwapInt32(&it.sst.isBeingRead, 1, 0)
}
it.exhausted = true
Expand All @@ -443,14 +691,14 @@ func (mi *MergeIterator) seekIterator(it *iterator, seekKey []byte) error {
it.exhausted = false
} else {
it.exhausted = true
// Release the isBeingRead flag if we can't find a valid entry

if it.sst != nil {
atomic.CompareAndSwapInt32(&it.sst.isBeingRead, 1, 0)
}
}
} else {
it.exhausted = true
// Release the isBeingRead flag if iterator is invalid

if it.sst != nil {
atomic.CompareAndSwapInt32(&it.sst.isBeingRead, 1, 0)
}
Expand Down Expand Up @@ -547,13 +795,11 @@ func (mi *MergeIterator) Prev() ([]byte, []byte, int64, bool) {
}

if mi.ascending {
// If configured for ascending, Prev means descending
if err := mi.SetDirection(false); err != nil {
return nil, nil, 0, false
}
return mi.nextDescending()
} else {
// If configured for descending, Prev means ascending
if err := mi.SetDirection(true); err != nil {
return nil, nil, 0, false
}
Expand Down
Loading