Add PutWithTTL (#266)
1. add PutWithTTL to support key expiration (see the usage sketch below)
2. fix iterator functions: parse the raw value from the log record
3. fix index rebuilding during merge
roseduan authored Aug 21, 2023
1 parent d72f9b9 commit f36db24
Showing 5 changed files with 244 additions and 50 deletions.
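As a quick orientation before the diffs, here is a minimal usage sketch of the new TTL API. This is a hypothetical example, not part of the commit; it assumes the rosedb v2 Open/DefaultOptions/Get/Close API that the tests in this diff also use.

package main

import (
	"errors"
	"fmt"
	"time"

	"github.com/rosedblabs/rosedb/v2"
)

func main() {
	// Open a database with the default options (the tests below do the same).
	db, err := rosedb.Open(rosedb.DefaultOptions)
	if err != nil {
		panic(err)
	}
	defer func() { _ = db.Close() }()

	// Write a key that expires 100ms from now. Internally the deadline is
	// stored on the record as an absolute UnixNano timestamp (see batch.go).
	if err := db.PutWithTTL([]byte("name"), []byte("rosedb"), 100*time.Millisecond); err != nil {
		panic(err)
	}

	val, _ := db.Get([]byte("name")) // still readable before the deadline
	fmt.Println(string(val))

	time.Sleep(200 * time.Millisecond)
	if _, err := db.Get([]byte("name")); errors.Is(err, rosedb.ErrKeyNotFound) {
		fmt.Println("key expired") // expired keys read as not found
	}
}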
68 changes: 60 additions & 8 deletions batch.go
@@ -3,9 +3,9 @@ package rosedb
import (
"fmt"
"sync"
"time"

"github.com/bwmarrin/snowflake"

"github.com/rosedblabs/wal"
)

@@ -112,9 +112,35 @@ func (b *Batch) Put(key []byte, value []byte) error {
b.mu.Lock()
// write to pendingWrites
b.pendingWrites[string(key)] = &LogRecord{
Key: key,
Value: value,
Type: LogRecordNormal,
Expire: 0,
}
b.mu.Unlock()

return nil
}

// PutWithTTL adds a key-value pair with ttl to the batch for writing.
func (b *Batch) PutWithTTL(key []byte, value []byte, ttl time.Duration) error {
if len(key) == 0 {
return ErrKeyIsEmpty
}
if b.db.closed {
return ErrDBClosed
}
if b.options.ReadOnly {
return ErrReadOnlyBatch
}

b.mu.Lock()
// write to pendingWrites
b.pendingWrites[string(key)] = &LogRecord{
Key: key,
Value: value,
Type: LogRecordNormal,
Expire: time.Now().Add(ttl).UnixNano(),
}
b.mu.Unlock()

@@ -130,11 +156,12 @@ func (b *Batch) Get(key []byte) ([]byte, error) {
return nil, ErrDBClosed
}

now := time.Now().UnixNano()
// get from pendingWrites
if b.pendingWrites != nil {
b.mu.RLock()
if record := b.pendingWrites[string(key)]; record != nil {
- if record.Type == LogRecordDeleted {
+ if record.Type == LogRecordDeleted || (record.Expire > 0 && record.Expire <= now) {
b.mu.RUnlock()
return nil, ErrKeyNotFound
}
@@ -154,10 +181,14 @@ func (b *Batch) Get(key []byte) ([]byte, error) {
return nil, err
}

// check if the record is deleted or expired
record := decodeLogRecord(chunk)
if record.Type == LogRecordDeleted {
panic("Deleted data cannot exist in the index")
}
if record.Expire > 0 && record.Expire <= now {
return nil, ErrKeyNotFound
}
return record.Value, nil
}

@@ -207,9 +238,24 @@ func (b *Batch) Exist(key []byte) (bool, error) {
b.mu.RUnlock()
}

- // check if the key exists in data file
+ // check if the key exists in index
position := b.db.index.Get(key)
- return position != nil, nil
if position == nil {
return false, nil
}

// check if the record is deleted or expired
chunk, err := b.db.dataFiles.Read(position)
if err != nil {
return false, err
}

now := time.Now().UnixNano()
record := decodeLogRecord(chunk)
if record.Type == LogRecordDeleted || (record.Expire > 0 && record.Expire <= now) {
return false, nil
}
return true, nil
}

// Commit commits the batch. If the batch is readonly or empty, it will return directly.
@@ -241,8 +287,14 @@ func (b *Batch) Commit() error {
batchId := b.batchId.Generate()
positions := make(map[string]*wal.ChunkPosition)

now := time.Now().UnixNano()
// write to wal
for _, record := range b.pendingWrites {
// skip the expired record
if record.Expire > 0 && record.Expire <= now {
continue
}

record.BatchId = uint64(batchId)
encRecord := encodeLogRecord(record)
pos, err := b.db.dataFiles.Write(encRecord)
@@ -291,7 +343,7 @@ func (b *Batch) Commit() error {
return nil
}

- // Rollback discards a uncommitted batch instance.
+ // Rollback discards an uncommitted batch instance.
// the discard operation will clear the buffered data and release the lock.
func (b *Batch) Rollback() error {
defer b.unlock()
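The batch.go changes above store a TTL as an absolute deadline (time.Now().Add(ttl).UnixNano()) in the new Expire field and enforce it lazily: Get and Exist treat records whose deadline has passed as not found, and Commit skips pending writes that have already expired. Below is a hypothetical sketch of using PutWithTTL inside an explicit batch; it assumes db.NewBatch and DefaultBatchOptions from rosedb v2, which do not appear in this diff.

// Sketch only: mixing a plain Put and a PutWithTTL in one batch.
func writeSession(db *rosedb.DB) error {
	batch := db.NewBatch(rosedb.DefaultBatchOptions) // assumed rosedb v2 API
	// A plain Put leaves Expire at 0, meaning the key never expires.
	if err := batch.Put([]byte("user:1"), []byte("alice")); err != nil {
		_ = batch.Rollback()
		return err
	}
	// PutWithTTL records now+ttl as the deadline; after that point Get/Exist
	// report the key as missing, and Commit drops it if already expired.
	if err := batch.PutWithTTL([]byte("session:1"), []byte("token"), time.Minute); err != nil {
		_ = batch.Rollback()
		return err
	}
	return batch.Commit()
}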
84 changes: 71 additions & 13 deletions db.go
@@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"sync"
"time"

"github.com/bwmarrin/snowflake"
"github.com/gofrs/flock"
@@ -233,6 +234,25 @@ func (db *DB) Put(key []byte, value []byte) error {
return batch.Commit()
}

// PutWithTTL puts a key-value pair into the database with a ttl.
// Actually, it will open a new batch and commit it.
// You can think of the batch as having only one PutWithTTL operation.
func (db *DB) PutWithTTL(key []byte, value []byte, ttl time.Duration) error {
batch := db.batchPool.Get().(*Batch)
defer func() {
batch.reset()
db.batchPool.Put(batch)
}()
// This is a single put operation, so we can set Sync to false:
// the data will be written to the WAL,
// and the WAL file will be synced to disk according to the DB options.
batch.init(false, false, db).withPendingWrites()
if err := batch.PutWithTTL(key, value, ttl); err != nil {
return err
}
return batch.Commit()
}

// Get the value of the specified key from the database.
// Actually, it will open a new batch and commit it.
// You can think of the batch as having only one Get operation.
@@ -294,11 +314,15 @@ func (db *DB) Ascend(handleFn func(k []byte, v []byte) (bool, error)) {
defer db.mu.RUnlock()

db.index.Ascend(func(key []byte, pos *wal.ChunkPosition) (bool, error) {
- val, err := db.dataFiles.Read(pos)
+ chunk, err := db.dataFiles.Read(pos)
if err != nil {
- return false, nil
+ return false, err
}
- return handleFn(key, val)
value, err := db.checkValue(chunk)
if err != nil {
return false, err
}
return handleFn(key, value)
})
}

@@ -308,11 +332,15 @@ func (db *DB) AscendRange(startKey, endKey []byte, handleFn func(k []byte, v []b
defer db.mu.RUnlock()

db.index.AscendRange(startKey, endKey, func(key []byte, pos *wal.ChunkPosition) (bool, error) {
- val, err := db.dataFiles.Read(pos)
+ chunk, err := db.dataFiles.Read(pos)
if err != nil {
return false, nil
}
- return handleFn(key, val)
value, err := db.checkValue(chunk)
if err != nil {
return false, err
}
return handleFn(key, value)
})
}

@@ -322,11 +350,15 @@ func (db *DB) AscendGreaterOrEqual(key []byte, handleFn func(k []byte, v []byte)
defer db.mu.RUnlock()

db.index.AscendGreaterOrEqual(key, func(key []byte, pos *wal.ChunkPosition) (bool, error) {
- val, err := db.dataFiles.Read(pos)
+ chunk, err := db.dataFiles.Read(pos)
if err != nil {
return false, nil
}
- return handleFn(key, val)
value, err := db.checkValue(chunk)
if err != nil {
return false, err
}
return handleFn(key, value)
})
}

@@ -336,11 +368,15 @@ func (db *DB) Descend(handleFn func(k []byte, v []byte) (bool, error)) {
defer db.mu.RUnlock()

db.index.Descend(func(key []byte, pos *wal.ChunkPosition) (bool, error) {
- val, err := db.dataFiles.Read(pos)
+ chunk, err := db.dataFiles.Read(pos)
if err != nil {
return false, nil
}
- return handleFn(key, val)
value, err := db.checkValue(chunk)
if err != nil {
return false, err
}
return handleFn(key, value)
})
}

@@ -350,11 +386,15 @@ func (db *DB) DescendRange(startKey, endKey []byte, handleFn func(k []byte, v []
defer db.mu.RUnlock()

db.index.DescendRange(startKey, endKey, func(key []byte, pos *wal.ChunkPosition) (bool, error) {
- val, err := db.dataFiles.Read(pos)
+ chunk, err := db.dataFiles.Read(pos)
if err != nil {
return false, nil
}
- return handleFn(key, val)
value, err := db.checkValue(chunk)
if err != nil {
return false, err
}
return handleFn(key, value)
})
}

@@ -364,14 +404,27 @@ func (db *DB) DescendLessOrEqual(key []byte, handleFn func(k []byte, v []byte) (
defer db.mu.RUnlock()

db.index.DescendLessOrEqual(key, func(key []byte, pos *wal.ChunkPosition) (bool, error) {
- val, err := db.dataFiles.Read(pos)
+ chunk, err := db.dataFiles.Read(pos)
if err != nil {
return false, nil
}
- return handleFn(key, val)
value, err := db.checkValue(chunk)
if err != nil {
return false, err
}
return handleFn(key, value)
})
}

func (db *DB) checkValue(chunk []byte) ([]byte, error) {
record := decodeLogRecord(chunk)
now := time.Now().UnixNano()
if record.Type == LogRecordDeleted || (record.Expire > 0 && record.Expire <= now) {
return nil, ErrKeyNotFound
}
return record.Value, nil
}

func checkOptions(options Options) error {
if options.DirPath == "" {
return errors.New("database dir path is empty")
@@ -391,6 +444,7 @@ func (db *DB) loadIndexFromWAL() error {
return err
}
indexRecords := make(map[uint64][]*IndexRecord)
now := time.Now().UnixNano()
// get a reader for WAL
reader := db.dataFiles.NewReader()
for {
@@ -435,6 +489,10 @@ func (db *DB) loadIndexFromWAL() error {
// so put the record into index directly.
db.index.Put(record.Key, position)
} else {
// expired records should not be indexed
if record.Expire > 0 && record.Expire <= now {
continue
}
// put the record into the temporary indexRecords
indexRecords[record.BatchId] = append(indexRecords[record.BatchId],
&IndexRecord{
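The iterator fixes above make the Ascend/Descend family decode each WAL chunk and hand the parsed value to the callback instead of the raw encoded record; the new checkValue helper reports ErrKeyNotFound for deleted or expired records. A hypothetical caller using the handleFn signature shown in the hunks above (db is assumed to be an open *rosedb.DB, not part of this diff):

// Sketch only: collect all key/value pairs via the fixed Ascend.
pairs := make(map[string]string)
db.Ascend(func(k []byte, v []byte) (bool, error) {
	pairs[string(k)] = string(v) // v is the decoded value, not the raw log record
	return true, nil             // return false or a non-nil error to stop early
})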
73 changes: 73 additions & 0 deletions db_test.go
@@ -4,6 +4,7 @@ import (
"math/rand"
"sync"
"testing"
"time"

"github.com/rosedblabs/rosedb/v2/utils"
"github.com/stretchr/testify/assert"
@@ -356,3 +357,75 @@ func TestDB_DescendLessOrEqual(t *testing.T) {
})
assert.Equal(t, []string{"grape", "date", "cherry", "banana", "apple"}, resultDescendLessOrEqual)
}

func TestDB_PutWithTTL(t *testing.T) {
options := DefaultOptions
db, err := Open(options)
assert.Nil(t, err)
defer destroyDB(db)

err = db.PutWithTTL(utils.GetTestKey(1), utils.RandomValue(128), time.Millisecond*100)
assert.Nil(t, err)
val1, err := db.Get(utils.GetTestKey(1))
assert.Nil(t, err)
assert.NotNil(t, val1)
time.Sleep(time.Millisecond * 200)
val2, err := db.Get(utils.GetTestKey(1))
assert.Equal(t, err, ErrKeyNotFound)
assert.Nil(t, val2)

err = db.PutWithTTL(utils.GetTestKey(2), utils.RandomValue(128), time.Millisecond*200)
// rewrite
err = db.Put(utils.GetTestKey(2), utils.RandomValue(128))
assert.Nil(t, err)
time.Sleep(time.Millisecond * 200)
val3, err := db.Get(utils.GetTestKey(2))
assert.Nil(t, err)
assert.NotNil(t, val3)

err = db.Close()
assert.Nil(t, err)

db2, err := Open(options)
assert.Nil(t, err)

val4, err := db2.Get(utils.GetTestKey(1))
assert.Equal(t, err, ErrKeyNotFound)
assert.Nil(t, val4)

val5, err := db2.Get(utils.GetTestKey(2))
assert.Nil(t, err)
assert.NotNil(t, val5)

_ = db2.Close()
}

func TestDB_PutWithTTL_Merge(t *testing.T) {
options := DefaultOptions
db, err := Open(options)
assert.Nil(t, err)
defer destroyDB(db)
for i := 0; i < 100; i++ {
err = db.PutWithTTL(utils.GetTestKey(i), utils.RandomValue(10), time.Second*2)
assert.Nil(t, err)
}
for i := 100; i < 150; i++ {
err = db.PutWithTTL(utils.GetTestKey(i), utils.RandomValue(10), time.Second*20)
assert.Nil(t, err)
}
time.Sleep(time.Second * 3)

err = db.Merge(true)
assert.Nil(t, err)

for i := 0; i < 100; i++ {
val, err := db.Get(utils.GetTestKey(i))
assert.Nil(t, val)
assert.Equal(t, err, ErrKeyNotFound)
}
for i := 100; i < 150; i++ {
val, err := db.Get(utils.GetTestKey(i))
assert.Nil(t, err)
assert.NotNil(t, val)
}
}